#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Aug 16 16:46:10 2023

@author: d
"""

import re
import string
import numpy as np
import pandas as pd

import seaborn as sn
import matplotlib as mpl
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import matplotlib.dates as mdates
import matplotlib.gridspec as gridspec
from matplotlib.ticker import MaxNLocator

from itertools import product
import statsmodels.formula.api as smf
from sklearn.metrics import mean_squared_error, r2_score

# Interactive mode off: works around a memory issue in Spyder.
mpl.interactive(False)

_base_rc = {"xtick.bottom": True, "ytick.left": True,
            "xtick.direction": "in", "ytick.direction": "in",
            "figure.dpi": 300, "savefig.dpi": 300}
sn.set_theme(style="white", context="paper", font_scale=2.3, rc=_base_rc)

# Set to True to run the formula-combination searches guarded by this flag.
run_combinations = False

#%%
inside = pd.read_csv("inside_cleaned.csv", parse_dates = True)
outside = pd.read_csv("outside_cleaned.csv", parse_dates = True)
inside["datetime"] = pd.to_datetime(inside["datetime"])
outside["datetime"] = pd.to_datetime(outside["datetime"])

# Keep a reading only if it AND the reading before it are below 10 ppm.
# The same mask is used for the reported counts and for the filter, so the
# printed numbers always match the rows actually dropped.
# NOTE(review): shift(1) puts NaN in the first row and NaN < 10 is False,
# so the first row of each file is always dropped -- confirm intended.
inside_keep = (inside["CH4"] < 10) & (inside["CH4"].shift(1) < 10)
outside_keep = (outside["CH4"] < 10) & (outside["CH4"].shift(1) < 10)

print("Inside over 10 ppm: ",
      len(inside) - int(inside_keep.sum()),
      " / ", len(inside))
print("Outside over 10 ppm: ",
      len(outside) - int(outside_keep.sum()),
      " / ", len(outside))

inside = inside[inside_keep].copy()
outside = outside[outside_keep].copy()
both = pd.concat([outside, inside])

print("proportion outside below 2.5", len(outside[outside["CH4"] < 2.5])/len(outside))
print("proportion inside below 2.5", len(inside[inside["CH4"] < 2.5])/len(inside))

print("# outside below 2.3", len(outside[outside["CH4"] < 2.3]))
print("# inside below 2.3", len(inside[inside["CH4"] < 2.3]))

#%%
#
# correlations
#

def plot_corrs(df, title):
    """Draw a pair grid of the sensor and reference variables.

    Lower triangle: scatter plots. Diagonal: step histograms (none for the
    Time column). Upper triangle: the Pearson r of each pair, printed on a
    background coloured from the ``bwr`` colormap.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain tgs2600, tgs2611, temp_c, rh, CH4, H2O and time columns.
    title : str
        Figure suptitle.

    Returns
    -------
    seaborn.PairGrid
    """
    columns = ["tgs2600", "tgs2611", "temp_c", "rh", "CH4", "H2O", "time"]
    display_names = {"tgs2600": "TGS2600",
                     "tgs2611": "TGS2611",
                     "temp_c": "T",
                     "rh": "RH",
                     "CH4": "$\\mathrm{CH_4}$",
                     "H2O": "$\\mathrm{H_2O}$",
                     "time": "Time"}
    variables = df[columns].rename(columns=display_names)

    # With limits of +-1.5, |r| = 1 maps to 1/6 or 5/6 of bwr, never its
    # fully saturated ends.
    norm = mpl.colors.Normalize(vmin=-1.5, vmax=1.5)

    # Seaborn calls this as func(x=..., y=..., **kws) and checks for a
    # ``hue`` parameter, so the x / y / hue names must stay.
    def annotate_r(x, y, bg, hue=None, **kw):
        r = x.corr(y, "pearson")
        ax = plt.gca()
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        for spine in ax.spines.values():
            spine.set_visible(False)
        if bg:
            ax.set_facecolor(cm.bwr(norm(r)))
        ax.annotate(f"{r:.2f}", [.5, .5], xycoords="axes fraction",
                    ha='center', va='center', fontsize=30)

    # Seaborn calls diagonal functions as func(x=series, **kws).
    def histogram_except_time(x, **kwargs):
        if x.name == "Time":
            return
        sn.histplot(x, **kwargs)

    grid = sn.PairGrid(variables, diag_sharey=False)
    grid.map_lower(sn.scatterplot, color="black", edgecolor=None, size=5)
    grid.map_diag(histogram_except_time, color="green", alpha=1, element="step")
    grid.map_upper(annotate_r, bg=True)
    grid.set(yticklabels=[], xticklabels=[])

    grid.fig.subplots_adjust(top=0.95)
    grid.fig.suptitle(title)
    return grid

# Larger fonts for the correlation grids only; reset at the end of the cell.
sn.set_theme(style="white", context="paper", font_scale=3,
             rc={"xtick.bottom": True, "ytick.left": True,
                 "xtick.direction": "in", "ytick.direction": "in",
                 "figure.dpi": 300, "savefig.dpi": 300})

in_corrs = plot_corrs(inside, "B: Inside")
out_corrs = plot_corrs(outside, "A: Outside")

# Save each grid to PNG and close it, then stack the two images vertically.
for grid, png in ((in_corrs, 'in.png'), (out_corrs, 'out.png')):
    grid.savefig(png, dpi=300)
    plt.close(grid.fig)

f, ax = plt.subplots(2, 1, figsize=(20, 20))
for panel, png in zip(ax, ('out.png', 'in.png')):
    panel.imshow(mpimg.imread(png))
    panel.set_axis_off()
plt.tight_layout()
plt.show()

# Back to the font scale used by the rest of the script.
sn.set_theme(style="white", context="paper", font_scale=2.3,
             rc={"xtick.bottom": True, "ytick.left": True,
                 "xtick.direction": "in", "ytick.direction": "in",
                 "figure.dpi": 300, "savefig.dpi": 300})


#%%
#
# time series
#

# Column widths proportional to each dataset's time span (longest = 1).
spans = [frame["datetime"].iat[-1] - frame["datetime"].iat[0]
         for frame in (outside, inside)]
ratios = [span / max(spans) for span in spans]

f, axes = plt.subplots(4, 2, sharex="col", sharey="row",
                       gridspec_kw={'width_ratios': ratios}, figsize=(20, 16))
plt.subplots_adjust(wspace=0.05, hspace=0.05)

def plot_ts(which, name, color, ax, ticks, data = both, label = True):
    """Scatter column ``which`` of ``data`` against ``datetime`` on ``ax``.

    The y label takes ``color``. With ``label=False`` both the y label and
    the y tick labels are blanked. Y ticks use
    ``MaxNLocator(prune='lower', nbins=ticks, min_n_ticks=ticks)``.
    ``data`` defaults to the module-level ``both`` frame as it was when this
    function was defined. Returns the axes.
    """
    plot_ax = sn.scatterplot(data=data, x="datetime", y=which, color=color,
                             ax=ax, s=1.5, edgecolor="none", legend=False)
    plot_ax.yaxis.label.set_color(color)
    if label:
        plot_ax.set_ylabel(name)
    else:
        plot_ax.set_ylabel(None)
        plot_ax.set_yticklabels([])
    plot_ax.yaxis.set_major_locator(
        MaxNLocator(prune='lower', nbins=ticks, min_n_ticks=ticks))
    return plot_ax

# Twin y axes: RH on the H2O row, TGS2600 on the TGS2611-E00 row.
rhs = [axes[1, col].twinx() for col in (0, 1)]
tgs2600s = [axes[3, col].twinx() for col in (0, 1)]

for d, i in [[outside, 0], [inside, 1]]:
    # Twin-axis labels only on the right-hand (inside) column.
    twin_label = (i != 0)
    plot_ts("temp_c", "Temperature (°C)", "black", axes[0, i], 4, data = d)
    plot_ts("rh", "RH (%)", "blue", rhs[i], 4, data = d, label = twin_label)
    plot_ts("CH4", "$\\mathrm{CH_4}$ (ppm)", "black", axes[2, i], 4, data = d)
    plot_ts("H2O", "$\\mathrm{H_2O}$ (%v)", "black", axes[1, i], 4, data = d)
    plot_ts("tgs2600", "TGS2600 (kΩ)", "red", tgs2600s[i], 4, data = d, label = twin_label)
    plot_ts("tgs2611", "TGS2611-E00 (kΩ)", "black", axes[3, i], 4, data = d)

    bottom = axes[3, i]
    bottom.xaxis.set_minor_locator(mdates.DayLocator())
    bottom.xaxis.set_major_locator(mdates.DayLocator(bymonthday = 1))
    bottom.xaxis.set_major_formatter(mdates.DateFormatter("%-d %b"))
    bottom.set_xlabel("")

# Twin axes are not row-shared, so align their ranges across columns by hand.
for twins in (rhs, tgs2600s):
    ylim = [min(a.get_ylim()[0] for a in twins),
            max(a.get_ylim()[1] for a in twins)]
    for a in twins:
        a.set_ylim(ylim)

axes[0, 0].set_title("Outside data")
axes[0, 1].set_title("Inside data")

# Panel tags: row letter (A-D) followed by column number (1-2).
for i in [0, 1]:
    for j, letter in enumerate(string.ascii_uppercase[:4]):
        axes[j, i].text(0.025, 0.875, letter + str(i + 1),
                        transform = axes[j, i].transAxes,
                        bbox = {"facecolor": "white", "alpha": 0.75, "pad": 2})
plt.show()

#%%
#
# show diurnal patterns
#

def plot_ts(which, name, color, ax, ticks, data = both, label = True):
    """Scatter column ``which`` of ``data`` against ``datetime`` (marker size 5).

    Colours the y label with ``color``. ``label=False`` blanks the y label
    and y tick labels. Y ticks come from
    ``MaxNLocator(prune='lower', nbins=ticks, min_n_ticks=ticks)``.
    Returns the axes.
    """
    locator = MaxNLocator(prune='lower', nbins=ticks, min_n_ticks=ticks)
    plot_ax = sn.scatterplot(data=data, x="datetime", y=which, color=color,
                             ax=ax, s=5, edgecolor="none", legend=False)
    plot_ax.yaxis.label.set_color(color)
    if not label:
        plot_ax.set_ylabel(None)
        plot_ax.set_yticklabels([])
    else:
        plot_ax.set_ylabel(name)
    plot_ax.yaxis.set_major_locator(locator)
    return plot_ax

f, axes = plt.subplots(4, 2, sharex = "col", sharey = False, figsize = (20, 20))
plt.subplots_adjust(wspace = 0.3, hspace = 0.05)

rhs = [axes[1, col].twinx() for col in (0, 1)]
tgs2600s = [axes[3, col].twinx() for col in (0, 1)]

# A one-week window from each dataset, paired with its column index.
week_windows = [
    [outside[outside["datetime"].between("9 aug 2022", "16 aug 2022")], 0],
    [inside[inside["datetime"].between("1 mar 2023", "8 mar 2023")], 1],
]

for d, i in week_windows:
    # Rows: temperature | H2O (+RH twin) | CH4 | TGS2611-E00 (+TGS2600 twin)
    plot_ts("temp_c", "Temperature (°C)", "black", axes[0, i], 4, data = d)
    plot_ts("rh", "RH (%)", "blue", rhs[i], 4, data = d)
    plot_ts("CH4", "$\\mathrm{CH_4}$ (ppm)", "black", axes[2, i], 4, data = d)
    plot_ts("H2O", "$\\mathrm{H_2O}$ (%v)", "black", axes[1, i], 4, data = d)
    plot_ts("tgs2600", "TGS2600 (kΩ)", "red", tgs2600s[i], 4, data = d)
    plot_ts("tgs2611", "TGS2611-E00 (kΩ)", "black", axes[3, i], 4, data = d)

    bottom = axes[3, i]
    bottom.xaxis.set_minor_locator(mdates.DayLocator())
    bottom.xaxis.set_major_locator(mdates.DayLocator(interval = 2))
    bottom.xaxis.set_major_formatter(mdates.DateFormatter("%-d %b"))
    bottom.set_xlabel("")

axes[0, 0].set_title("Outside data")
axes[0, 1].set_title("Inside data")

# Panel tags: row letter (A-D) followed by column number (1-2).
for i in [0, 1]:
    for j, letter in enumerate(string.ascii_uppercase[:4]):
        axes[j, i].text(0.025, 0.875, letter + str(i + 1),
                        transform = axes[j, i].transAxes,
                        bbox = {"facecolor": "white", "alpha": 0.75, "pad": 2})

plt.show()

#%%
def background_fit(formula, lh_trans, df, plot = True, eval_df = None,
                   target = "tgs2611", color = "H2O"):
    """Fit a background model on low-methane rows and apply it to all rows.

    An OLS model (statsmodels formula API) is fitted on the rows of ``df``
    with CH4 < 2.3 and used to predict a ``background`` for every row.

    Parameters
    ----------
    formula : str
        Patsy formula for ``smf.ols``; its left-hand side may be a transform
        of ``target`` (e.g. ``np.log(tgs2611)``).
    lh_trans : callable
        Maps model predictions back to the units of ``target`` (e.g.
        ``np.exp`` for a log left-hand side).
    df : pandas.DataFrame
        Data to fit on and annotate. Not modified; a copy is returned.
    plot : bool
        If true, print the model summary and show two diagnostic plots.
    eval_df : pandas.DataFrame, optional
        If given, its CH4 < 2.3 rows are used for the scores and plots
        instead of those of ``df``.
    target : str
        Column being modelled.
    color : str
        Hue column for the predicted-vs-actual plot.

    Returns
    -------
    tuple
        ``(df_copy, r2, rmse, model)``. ``df_copy`` gains ``background`` and
        ``ratio`` (= target / background) columns. ``r2`` and ``rmse``
        compare target with background on the evaluation rows. ``model`` is
        the fitted statsmodels results object.
    """
    df = df.copy()
    low = df[df["CH4"] < 2.3].copy()
    if eval_df is not None:
        low_eval = eval_df[eval_df["CH4"] < 2.3].copy()
    else:
        low_eval = low.copy()

    mod = smf.ols(formula = formula, data = low).fit()
    # Predictions are on the formula's left-hand-side scale; undo it.
    df.loc[:, "background"] = lh_trans(mod.predict(df))
    low_eval.loc[:, "background"] = lh_trans(mod.predict(low_eval))
    low_eval["error"] = low_eval["background"] - low_eval[target]

    if plot:
        print(mod.summary())

        # Predicted background vs. actual target, with a red 1:1 line.
        plt.figure(figsize = (8, 5))
        sn.scatterplot(data = low_eval, x = target, y = "background", s = 20,
                       edgecolor = "none", hue = color, palette = "viridis")
        plt.gca().plot([low_eval[target].min(), low_eval[target].max()],
                       [low_eval[target].min(), low_eval[target].max()],
                       color='red',  linewidth=1.0)
        plt.legend(bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0,
                   title = color, frameon = False)
        plt.title(formula + "\nCH4 < 2.3")
        plt.show()

        # Residuals over elapsed time, coloured by temperature; red zero line.
        plt.figure(figsize = (8, 5))
        sn.scatterplot(data = low_eval, x = "elapsed_time", y = "error", s = 20,
                       edgecolor = "none", hue = "temp_c", palette = "viridis")
        plt.gca().plot([low_eval["elapsed_time"].min(), low_eval["elapsed_time"].max()], [0, 0],
                       color='red',  linewidth=1.0)
        # The legend title names the hue column actually used (temp_c).
        plt.legend(bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0,
                   title = "temp_c", frameon = False)
        plt.title("error" + "\nCH4 < 2.3")
        plt.show()

    df.loc[:, "ratio"] = df[target] / df["background"]

    # RMSE as sqrt(MSE): mean_squared_error's squared=False argument was
    # deprecated in scikit-learn 1.4 and removed in 1.6.
    r2 = r2_score(low_eval[target], low_eval["background"])
    rmse = np.sqrt(mean_squared_error(low_eval[target], low_eval["background"]))
    return (df, r2, rmse, mod)

#%%
#
# Fit the background. Try different combinations.
#
def do_fit(df, df_name, formula, target, log_xfm = False):
    """Fit one background formula with ``background_fit`` and summarise it.

    ``formula`` is only the right-hand side. The response is ``target``, or
    ``np.log(target)`` when ``log_xfm`` is true (predictions are then mapped
    back with ``np.exp``).

    Returns a dict with keys ``formula`` (the full formula string), ``r2``,
    ``rmse``, ``f`` (the model's F statistic) and ``dataset`` (``df_name``).
    """
    if not log_xfm:
        lhs, trans = target, (lambda x: x)
    else:
        lhs, trans = "np.log(" + target + ")", np.exp
    full_formula = lhs + "~" + formula

    _, r2, rmse, mod = background_fit(full_formula, trans, df, plot=False, target=target)
    return dict(formula=full_formula, r2=r2, rmse=rmse, f=mod.fvalue, dataset=df_name)

if run_combinations:
    # Row-reversed copies of tgs2611. NOTE(review): nothing in this cell
    # uses them (target is tgs2611); presumably a control target for manual
    # experiments -- confirm.
    inside["dummy_tgs2611"] = inside["tgs2611"].values[::-1]
    outside["dummy_tgs2611"] = outside["tgs2611"].values[::-1]
    both["dummy_tgs2611"] = both["tgs2611"].values[::-1]

    # Each term is the raw variable, its log, or absent ("").
    products = product(["H2O", "np.log(H2O)", ""],
                       ["temp_c", "np.log(temp_c)", ""],
                       ["elapsed_time", "np.log(elapsed_time + 0.0001)", ""],
                       ["tgs2600", "np.log(tgs2600)", ""])

    # Join with "+", collapse the runs of "+" left by empty terms, and trim
    # the ends. The final product (all terms absent) is dropped by [:-1].
    # Raw string: "\+" in a plain literal is an invalid escape sequence.
    formulae = [re.sub(r"\++", "+", "+".join(p)).strip("+") for p in products][:-1]

    results = []
    target = "tgs2611"

    for terms in formulae:
        for df, df_name in [[inside, "inside"],
                            [outside, "outside"],
                            [both, "both"]]:
            results.append(do_fit(df, df_name, terms, target, True))
            results.append(do_fit(df, df_name, terms, target, False))

    results = pd.DataFrame(results).sort_values(by = "rmse", ascending = True)
    results["tgs2600"] = results["formula"].str.contains("tgs2600")

    no_rh = results[~results["formula"].str.contains("rh")]
    no_h2o = results[~results["formula"].str.contains("H2O")]

    # Per dataset: the 3 lowest-RMSE formulae with a TGS2600 term, then the
    # 3 lowest without one.
    best_backgrounds = pd.DataFrame()
    for d in results["dataset"].unique():
        in_dataset = no_rh["dataset"] == d
        best_backgrounds = pd.concat([best_backgrounds,
                                      no_rh[in_dataset & no_rh["tgs2600"]].head(3),
                                      no_rh[in_dataset & ~no_rh["tgs2600"]].head(3)])

#%%
#
# fit background by time bins
#

def time_bg_fit(bins, formula, df, target = "tgs2611", plot = False):
    """Fit the background separately within equal-width bins of the index.

    ``df`` is split with ``pd.cut(df.index, bins)``. ``background_fit`` is run
    on each bin with ``np.exp`` as the inverse transform, so the formula's
    left-hand side is expected to be ``np.log(target)``.

    Parameters
    ----------
    bins : int or sequence
        Passed to ``pd.cut`` on the index.
    formula : str
        Full background formula.
    df : pandas.DataFrame
        Input data (copied, not modified).
    target : str
        Sensor column being modelled.
    plot : bool
        Forwarded to ``background_fit`` for every bin.

    Returns
    -------
    tuple
        ``(df_copy, r2, rmse, params)``. ``df_copy`` gains ``time_bin``,
        ``background`` and ``ratio`` columns. ``r2`` and ``rmse`` compare
        ``target`` with ``background`` on the CH4 < 2.3 rows. ``params`` maps
        each bin interval to that bin's fitted coefficients.
    """
    df = df.copy()
    # NOTE(review): bins are cut on index *labels*. They follow time only
    # when the index is time-ordered. For a concatenation such as ``both``
    # (outside rows then inside rows, each keeping its own read_csv index),
    # rows from both datasets land in the same bins -- confirm intended.
    df["time_bin"] = pd.cut(df.index, bins)
    params = {}

    for t in df["time_bin"].unique():
        in_bin = df["time_bin"] == t
        data, _, _, mod = background_fit(formula, np.exp, df[in_bin], plot = plot,
                                         target = target)
        df.loc[in_bin, "background"] = data["background"]
        params[t] = mod.params

    df.loc[:, "ratio"] = df[target] / df["background"]

    # RMSE as sqrt(MSE): mean_squared_error's squared=False argument was
    # deprecated in scikit-learn 1.4 and removed in 1.6.
    low = df[df["CH4"] < 2.3]
    return (df, r2_score(low[target], low["background"]),
            np.sqrt(mean_squared_error(low[target], low["background"])),
            params)

#%%
#
# Fit the different approaches
#

# Long table of the CH4 < 2.3 rows from four background fits per dataset:
# Equations 1 and 2, each as one full fit and as a 10-bin piecewise fit.
results = pd.DataFrame()

fit_specs = [
    ("np.log(tgs2611)~np.log(H2O)+temp_c+np.log(elapsed_time + 0.0001)",
     "Equation 1, full fit", "Equation 1, piecewise fit"),
    ("np.log(tgs2611)~np.log(H2O)+temp_c+np.log(elapsed_time + 0.0001)+np.log(tgs2600)",
     "Equation 2, full fit", "Equation 2, piecewise fit"),
]

for df, name in [[inside, "Inside"], [outside, "Outside"]]:
    for formula, full_label, piecewise_label in fit_specs:
        full_fit, _, _, _ = background_fit(formula, np.exp, df, plot=False,
                                           target="tgs2611", color="elapsed_time")
        data, r2, rmse, params = time_bg_fit(10, formula, df)
        for fitted, fit_label in ((full_fit, full_label), (data, piecewise_label)):
            fitted["dataset"] = name
            fitted["fit"] = fit_label
            results = pd.concat([results, fitted[fitted["CH4"] < 2.3]])

#%%
#
# Show the different approaches for background fits
#
def plot_bg_fits(df, legend, labels):
    """Facet predicted vs. actual TGS2611-E00 by fit (columns) and dataset (rows).

    The code indexes axes 0-3 directly, so it assumes a 2x2 grid: axes 0-1
    are the first row and axes 2-3 the second. Each row gets square limits
    covering both of its panels, plus a red 1:1 line on every panel.

    Parameters
    ----------
    df : pandas.DataFrame
        Rows with tgs2611, background, elapsed_time, fit and dataset columns.
    legend : bool
        If false, the legend is kept but made fully transparent.
    labels : sequence of str
        Four panel tags drawn near the top-left of each axis.

    Returns
    -------
    seaborn.FacetGrid

    Side effect: sets ``plt.rcParams["legend.markerscale"] = 3``.
    """
    g = sn.relplot(data=df,
                   x="tgs2611", y="background", hue="elapsed_time",
                   palette="viridis", edgecolor="none",
                   col="fit", row="dataset", height=8,
                   facet_kws={"margin_titles": True, "sharex": "row", "sharey": "row",
                              "despine": False})
    g.set_titles(col_template="{col_name}", row_template="Dataset: {row_name}")
    g.set_axis_labels(x_var="Actual TGS2611-E00 (kΩ)",
                      y_var="Predicted TGS2611-E00 (kΩ)")

    panels = list(g.axes.flat)

    def row_limits(first, second):
        """Smallest interval covering the x and y limits of two panels."""
        pair = (panels[first], panels[second])
        low = min(min(p.get_xlim()[0], p.get_ylim()[0]) for p in pair)
        high = max(max(p.get_xlim()[1], p.get_ylim()[1]) for p in pair)
        return [low, high]

    top = row_limits(0, 1)
    bot = row_limits(2, 3)

    # Square each row; sharex/sharey="row" carries it to the row's other panel.
    for first, lims in ((0, top), (2, bot)):
        panels[first].set_xlim(*lims)
        panels[first].set_ylim(*lims)

    for panel, lims in zip(panels, (top, top, bot, bot)):
        panel.plot(lims, lims, color='red', linewidth=1.0)

    plt.rcParams["legend.markerscale"] = 3
    sn.move_legend(g, "lower right")
    g._legend.set_title("Elapsed time\n(days)")
    for handle in g._legend.legend_handles:
        handle._sizes = [200]

    if not legend:
        # Made invisible rather than removed, presumably so both figures keep
        # the same layout.
        for handle in g._legend.legend_handles:
            handle.set_alpha(0)
        for text in g._legend.texts:
            text.set_alpha(0)
        g._legend.set_title(None)

    for k, panel in enumerate(panels[:4]):
        panel.text(0.025, 0.925, labels[k], transform=panel.transAxes)
    plt.subplots_adjust(wspace=0)

    return g

# Equation 1 figure carries the legend; Equation 2 keeps an invisible one.
for prefix, show_legend, tags, png in (
        ("Equation 1", True, ["A", "B", "C", "D"], 'eq1.png'),
        ("Equation 2", False, ["E", "F", "G", "H"], 'eq2.png')):
    g = plot_bg_fits(results[results["fit"].str.startswith(prefix)], show_legend, tags)
    g.savefig(png, dpi=300)
    plt.close(g.fig)

# Stack the two saved PNGs into one figure.
f, ax = plt.subplots(2, 1, figsize=(40, 20))
for panel, png in zip(ax, ('eq1.png', 'eq2.png')):
    panel.imshow(mpimg.imread(png))
    panel.set_axis_off()
plt.tight_layout()
plt.show()

# Undo the marker scaling set inside plot_bg_fits.
plt.rcParams["legend.markerscale"] = 1

#%%
#
# Show the time-based change in coefficients
#

formula = "np.log(tgs2611)~np.log(H2O)+temp_c+np.log(elapsed_time + 0.0001)+np.log(tgs2600)"

# Facet titles for each model coefficient.
names = {'Intercept': 'A: Intercept',
         'np.log(H2O)': 'B: log($H_2O$)',
         'temp_c': 'C: Temperature',
         'np.log(elapsed_time + 0.0001)': 'D: log(Time)',
         'np.log(tgs2600)': 'E: log(TGS2600)'}

result_df, _, _, params = time_bg_fit(20, formula, both)

# Long format: one row per (bin, parameter), x = bin midpoint.
frames = []
for interval, coefs in params.items():
    frame = (pd.DataFrame(coefs)
             .reset_index()
             .rename(columns={"index": "parameter", 0: "value"}))
    frame["time"] = interval.mid
    frames.append(frame)
params_time = pd.concat([pd.DataFrame(), *frames])

params_time["parameter"] = params_time["parameter"].replace(names)

g = sn.FacetGrid(data=params_time, col="parameter", col_wrap=3,
                 height=6, aspect=1.5, sharex=True, sharey=False, despine=False)
g.set_titles(col_template="{col_name}")
g.map(sn.scatterplot, "time", "value", color="black", s=75)

g.set_xlabels("Time bin")
g.set_ylabels("Coefficient")
g.set(xticklabels=[], xticks=[])

for k in (2, 3):
    g.axes.flat[k].set_xlabel(None)

plt.show()

# Mark rows at bin boundaries: the first row of each bin, plus the row just
# before the next bin starts.
changed = result_df["time_bin"] != result_df["time_bin"].shift()
result_df["new_bin"] = changed | changed.shift(-1)
new_bins = result_df[result_df["new_bin"]]

#%%
# fit methane using the ratio
def methane_fit(df, df_name, formula, target, log_xfm = False, stratified = False):
    """Fit an OLS model for ``target`` and score its CH4 predictions.

    Parameters
    ----------
    df : pandas.DataFrame
        Data to fit and predict on (copied, not modified).
    df_name : str
        Label stored in the returned summary.
    formula : str
        Right-hand side of the model formula.
    target : str
        Response column. ``np.log(target)`` is modelled when ``log_xfm`` is
        true, and predictions are mapped back with ``np.exp``.
    log_xfm : bool
        Model the log of the target.
    stratified : bool
        Train on an equal-size random sample (``random_state=0``) from each
        ``methane_bin`` group. Requires a ``methane_bin`` column.

    Returns
    -------
    tuple
        ``(df_copy, summary, model)``. ``df_copy`` gains a ``predicted CH4``
        column. ``summary`` is a dict with formula, r2, rmse, f (F statistic)
        and dataset.

    Note: scores always compare against the ``CH4`` column, whatever
    ``target`` is.
    """
    df = df.copy()
    if log_xfm:
        formula = "np.log(" + target + ")~" + formula
        trans = np.exp
    else:
        formula = target + "~" + formula
        trans = lambda x:x

    if stratified:
        # Downsample every methane_bin to the size of the smallest one.
        size = min(df.groupby("methane_bin").count()["datetime"].to_list())
        train_df = df.groupby("methane_bin").apply(lambda x: x.sample(size,
                                                                      random_state = 0))
    else:
        train_df = df

    mod = smf.ols(formula = formula, data = train_df).fit()
    df.loc[:, "predicted CH4"] = trans(mod.predict(df))

    # RMSE as sqrt(MSE): mean_squared_error's squared=False argument was
    # deprecated in scikit-learn 1.4 and removed in 1.6.
    r2 = r2_score(df["CH4"], df["predicted CH4"])
    rmse = np.sqrt(mean_squared_error(df["CH4"], df["predicted CH4"]))

    return df, {"formula": formula,
                "r2": r2,
                "rmse": rmse,
                "f": mod.fvalue,
                "dataset": df_name}, mod

#%%
if run_combinations:
    bg_formula = "np.log(tgs2611) ~ temp_c + np.log(elapsed_time + 0.0001) + np.log(H2O) + np.log(tgs2600)"

    # Each term is the raw variable, its log, or absent (""). ratio and
    # background come from the background fit below.
    products = product(["H2O", "np.log(H2O)", ""],
                       ["temp_c", "np.log(temp_c)", ""],
                       ["elapsed_time", "np.log(elapsed_time + 0.0001)", ""],
                       ["tgs2600", "np.log(tgs2600)", ""],
                       ["tgs2611", "np.log(tgs2611)", ""],
                       ["ratio", "np.log(ratio)", ""],
                       ["background", "np.log(background)", ""])

    # Collapse the "+" runs left by empty terms; drop the all-empty product.
    # Raw string: "\+" in a plain literal is an invalid escape sequence.
    formulae = [re.sub(r"\++", "+", "+".join(p)).strip("+") for p in products][:-1]

    results = []
    target = "CH4"

    # time_bg_fit returns a NEW frame carrying "background" and "ratio".
    # Rebinding only the loop variable would discard it, so keep the fitted
    # frames; the formulae above reference those columns.
    fitted = []
    for df, df_name in [[inside, "inside"],
                        [outside, "outside"],
                        [both, "both"]]:
        df, _, _, _ = time_bg_fit(10, bg_formula, df)
        fitted.append([df, df_name])

    for i, terms in enumerate(formulae):
        if i % 10 == 0:
            print(i, "/", len(formulae), terms)
        for df, df_name in fitted:
            results.append(methane_fit(df, df_name, terms, target, True)[1])
            results.append(methane_fit(df, df_name, terms, target, False)[1])

    results = pd.DataFrame(results).sort_values(by = "rmse", ascending = True)

#%%
if run_combinations:
    # results is sorted by ascending rmse, so head(5) keeps the 5 best per dataset.
    best_methane = results.groupby("dataset").head(5)
    in_methane, out_methane, both_methane = (
        results[results["dataset"] == which] for which in ("inside", "outside", "both"))

#%%
#
# calibration route
#

def cal_route(bg_formula, target, time_fit = True, df = inside, plot = True):
    """Calibrate CH4 from the sensor/background ratio and diagnose outliers.

    Steps:

    1. Fit a background for ``target``: piecewise over 10 index bins via
       ``time_bg_fit``, or one global fit via ``background_fit``.
    2. Regress ``ratio ~ CH4`` and invert the line to predict CH4 from the
       ratio.
    3. Score the prediction.
    4. Flag the largest 1 % absolute errors as outliers.
    5. Tag rows that follow a data gap of more than 11 minutes, and rows that
       fall in three fixed March 2023 windows.

    Parameters
    ----------
    bg_formula : str
        Background formula. Predictions are mapped back with ``np.exp``, so
        its left-hand side is expected to be ``np.log(target)``.
    target : str
        Sensor column the background is fitted for.
    time_fit : bool
        Piecewise (True) or single (False) background fit.
    df : pandas.DataFrame
        Input data. Defaults to the module-level ``inside`` frame as it was
        when this function was defined. Not modified.
    plot : bool
        If true, draw the five-panel diagnostic figure.

    Returns
    -------
    tuple
        ``(df, outliers, r2, rmse, r2_no_outlier, rmse_no_outlier, bgr2,
        bgrmse)``: the annotated copy, the outlier rows, CH4 scores with and
        without outliers, and the background fit's r2 / rmse.

    Side effect: writes the outlier rows to ``outliers.csv``.
    """
    df = df.copy()
    # Signed CH4 change to the previous/next reading, and the time gaps.
    df["delta CH4 (previous)"] = df["CH4"].diff()
    df["delta CH4 (next)"] = df["CH4"].diff(-1)
    df["time gap (previous)"] = df["datetime"].diff()
    df["time gap (next)"] = df["datetime"].shift(-1) - df["datetime"]

    if time_fit:
        df, bgr2, bgrmse, _ = time_bg_fit(10, bg_formula, df, target = target, plot = False)
    else:
        df, bgr2, bgrmse, _ = background_fit(formula = bg_formula, lh_trans = np.exp,
                                             df = df, plot = False, target = target)

    # Fit ratio = a + b * CH4, then invert it: CH4 = ratio / b - a / b.
    mod = smf.ols(formula = "ratio ~ CH4", data = df).fit()
    ratio_coeff = 1/mod.params["CH4"]
    intercept = (-1 * mod.params["Intercept"])/mod.params["CH4"]

    df["predicted CH4"] = df["ratio"] * ratio_coeff + intercept

    # RMSE as sqrt(MSE): mean_squared_error's squared=False argument was
    # deprecated in scikit-learn 1.4 and removed in 1.6.
    r2 = r2_score(df["CH4"], df["predicted CH4"])
    rmse = np.sqrt(mean_squared_error(df["CH4"], df["predicted CH4"]))

    # Outliers: absolute error above the 99th percentile.
    df.loc[:,"error"] = df["predicted CH4"] - df["CH4"]
    df.loc[:,"abs error"] = df["error"].abs()
    df.loc[:,"outlier"] = df["abs error"] > df["abs error"].quantile(0.99)

    # Largest absolute CH4 step to either neighbouring reading.
    df["CH4 change"] = df[["delta CH4 (previous)", "delta CH4 (next)"]].abs().max(axis = 1)
    df.loc[df["time gap (previous)"] > pd.Timedelta(11, "minutes"), "Condition"] = "Gap"
    df.loc[df["datetime"].between("march 17 2023 16:00", "march 17 2023 18:10"), "March weirdness"] = "March 17th"
    df.loc[df["datetime"].between("march 19 2023 11:00", "march 19 2023 12:20"), "March weirdness"] = "March 19th"
    df.loc[df["datetime"].between("march 22 2023 13:50", "march 22 2023 21:40"), "March weirdness"] = "March 22nd"
    df.loc[df["Condition"].isna(), "Condition"] = "Normal"

    # NOTE(review): no literal "nan" strings are written above, so this looks
    # like a no-op safeguard.
    df["March weirdness"] = df["March weirdness"].replace("nan", pd.NA)

    no_outliers = df[~df["outlier"]]
    outliers = df[df["outlier"]]

    outliers.to_csv("outliers.csv")

    r2_no_outlier = r2_score(no_outliers["CH4"], no_outliers["predicted CH4"])
    rmse_no_outlier = np.sqrt(mean_squared_error(no_outliers["CH4"],
                                                 no_outliers["predicted CH4"]))

    if not plot:
        return df, outliers, r2, rmse, r2_no_outlier, rmse_no_outlier, bgr2, bgrmse

    # Layout: two panels centred on top, three panels below, all sharing axes.
    fig = plt.figure(figsize = (20, 13))
    gs = gridspec.GridSpec(2, 6)

    ax1 = plt.subplot(gs[0, 1:3])
    ax2 = plt.subplot(gs[0, 3:5], sharex = ax1, sharey = ax1)

    ax3 = plt.subplot(gs[1, 0:2], sharex = ax1, sharey = ax1)
    ax4 = plt.subplot(gs[1, 2:4], sharex = ax1, sharey = ax1)
    ax5 = plt.subplot(gs[1, 4:6], sharex = ax1, sharey = ax1)

    gs.tight_layout(fig, h_pad = 1.5, w_pad = 0)

    fit_ax = sn.scatterplot(data = df, x = "CH4", y = "predicted CH4", s = 20,
                   edgecolor = "none", color = "black", legend = False, ax = ax1)

    fit_ax.set_ylabel("Predicted $CH_4$ (ppm)")
    fit_ax.set_title("A: $CH_4$ regression fit")

    outlier_ax = sn.scatterplot(data = df, x = "CH4", y = "predicted CH4", s = 20,
                   edgecolor = "none", hue = "outlier", palette = ["black", "orange"],
                   legend = False, ax = ax2)
    outlier_ax.set_title("B: Regression outliers")

    sn.scatterplot(data = df.sort_values(by = "CH4 change", ascending = True),
                   x = "CH4", y = "predicted CH4", s = 20,
                   edgecolor = "none", hue = "CH4 change", size = "CH4 change",
                   palette = sn.light_palette("black", as_cmap=True),
                   legend = False, ax = ax3)
    ax3.set_title("C: $CH_4$ change (ppm per 10 minutes)")
    norm = plt.Normalize(df["CH4 change"].min(), df["CH4 change"].max())
    cmap = sn.light_palette("black", as_cmap = True)
    sm = plt.cm.ScalarMappable(cmap = cmap, norm = norm)
    sm.set_array([])

    cb = ax3.figure.colorbar(sm, ax = ax3, location = "right", anchor = (1, 0.5),
                             fraction = 0.05, pad = 0)
    cb.ax.yaxis.set_ticks_position("right")

    sn.scatterplot(data = df[df["Condition"] != "Gap"], x = "CH4", y = "predicted CH4", s = 20,
                   edgecolor = "none", color = "black", legend = False, ax = ax4)
    sn.scatterplot(data = df[df["Condition"] == "Gap"], x = "CH4", y = "predicted CH4", s = 20,
                   edgecolor = "none", color = "red",
                   legend = False, ax = ax4)
    ax4.set_title("D: After data gap")

    sn.scatterplot(data = df[df["March weirdness"].isna()], x = "CH4", y = "predicted CH4", s = 20,
                   edgecolor = "none", color = "black", legend = False, ax = ax5)
    march_ax = sn.scatterplot(data = df[df["March weirdness"].notna()], x = "CH4", y = "predicted CH4", s = 20,
                              edgecolor = "none", hue = "March weirdness", legend = True, ax = ax5)
    march_ax.legend(title = "", frameon = False, markerscale = 2)
    ax5.set_title("E: March events")

    # Red 1:1 line on every panel.
    for a in [ax1, ax2, ax3, ax4, ax5]:
        a.plot([df["CH4"].min(), df["CH4"].max()],
               [df["CH4"].min(), df["CH4"].max()],
               color='red',  linewidth=1.0)

        a.set_xlabel(None)

    for a in [ax2, ax4, ax5]:
        a.set_ylabel(None)
        plt.setp(a.get_yticklabels(), visible=False)

    ax3.set_ylabel("Predicted $CH_4$ (ppm)")
    ax4.set_xlabel("Actual $CH_4$ (ppm)")

    plt.show()

    return df, outliers, r2, rmse, r2_no_outlier, rmse_no_outlier, bgr2, bgrmse

bg_formula = "np.log(tgs2611) ~ temp_c + np.log(elapsed_time + 0.0001) + np.log(H2O) + np.log(tgs2600)"
target = "tgs2611"
(df, outliers, r2, rmse, r2_no_outlier,
 rmse_no_outlier, bgr2, bgrmse) = cal_route(bg_formula, target, time_fit=True, plot=True)

# Share of outliers explained by each candidate cause, in percent.
n_outliers = len(outliers)
is_march = ~outliers["March weirdness"].isna()
is_gap = outliers["Condition"] == "Gap"
change = outliers["CH4 change"]

print("No. outliers: ", n_outliers)
print("March events: ", len(outliers[is_march]) / n_outliers * 100, "%")
print("Gap: ", len(outliers[is_gap]) / n_outliers * 100, "%")
print("> 1 ppm change: ", len(outliers[change > 1]) / n_outliers * 100, "%")
print("> 2 ppm change: ", len(outliers[change > 2]) / n_outliers * 100, "%")
print("> 5 ppm change: ", len(outliers[change > 5]) / n_outliers * 100, "%")
unexplained = ~is_march & ~is_gap & (change < 1)
print("None of the above: ", len(outliers[unexplained]) / n_outliers * 100, "%")
print()
print("Gap, pct of all datapoints: ",
      len(df[df["Condition"] == "Gap"]) / len(df) * 100, "%")

#%%
#
# time series of outliers
#

# Same calibration as the previous cell, recomputed here (presumably so this
# cell can run on its own). plot defaults to True, so the diagnostic figure
# is drawn again and outliers.csv is rewritten.
bg_formula = "np.log(tgs2611) ~ temp_c + np.log(elapsed_time + 0.0001) + np.log(H2O) + np.log(tgs2600)"
target = "tgs2611"

(df, outliers, r2, rmse, r2_no_outlier,
 rmse_no_outlier, bgr2, bgrmse) = cal_route(bg_formula, target, True)

# Three stacked panels sharing the datetime axis.
f, axes = plt.subplots(3, 1, sharex=True, sharey=False, figsize=(20, 12))

def plot_ts(which, name, color, ax, ticks):
    """Scatter column ``which`` of the module-level ``inside`` frame on ``ax``.

    Marker size 10. The y label is ``name``, drawn in ``color``. Y ticks use
    ``MaxNLocator(prune='lower', nbins=ticks, min_n_ticks=ticks)``.
    Returns the axes.
    """
    locator = MaxNLocator(prune='lower', nbins=ticks, min_n_ticks=ticks)
    plot_ax = sn.scatterplot(data=inside, x="datetime", y=which, color=color,
                             ax=ax, s=10, edgecolor="none", legend=False)
    plot_ax.set_ylabel(name)
    plot_ax.yaxis.label.set_color(color)
    plot_ax.yaxis.set_major_locator(locator)
    return plot_ax

# Panels: temperature (+H2O twin), CH4, TGS2611-E00 (+TGS2600 twin).
plot_ts("temp_c", "Temperature\n(°C)", "red", axes[0], 4)
#plot_ts("rh", "RH (%)", "blue", axes[0].twinx(), 4)
plot_ts("CH4", "$\\mathrm{CH_4}$ (ppm)", "black", axes[1], 4)
plot_ts("H2O", "$\\mathrm{H_2O}$ (%)", "blue", axes[0].twinx(), 4)
plot_ts("tgs2600", "TGS2600 (kΩ)", "red", axes[2].twinx(), 4)
plot_ts("tgs2611", "TGS2611-E00\n(kΩ)", "black", axes[2], 4)

date_axis = axes[2].xaxis
date_axis.set_minor_locator(mdates.DayLocator())
date_axis.set_major_locator(mdates.DayLocator(bymonthday = 1))
date_axis.set_major_formatter(mdates.DateFormatter("%-d %b"))
axes[2].set_xlabel("")

# Shade the 10 minutes starting at each outlier timestamp on every panel.
ten_minutes = pd.Timedelta(10, "minutes")
for a in axes:
    for stamp in outliers["datetime"]:
        a.axvspan(mdates.date2num(stamp), mdates.date2num(stamp + ten_minutes),
                  color = "purple", alpha = 0.2)
plt.show()


# Inside data for 22-23 March 2023; plot_ts below reads this module-level name.
in_window = inside["datetime"].between("march 22 2023 0:01am", "march 23 2023 11:59pm")
march_22 = inside[in_window]

f, axes = plt.subplots(3, 1, sharex=True, sharey=False, figsize=(20, 12))

def plot_ts(which, name, color, ax, ticks):
    """Scatter column ``which`` of the module-level ``march_22`` frame on ``ax``.

    Marker size 50. The y label is ``name``, drawn in ``color``. Y ticks use
    ``MaxNLocator(prune='lower', nbins=ticks, min_n_ticks=ticks)``.
    Returns the axes.
    """
    locator = MaxNLocator(prune='lower', nbins=ticks, min_n_ticks=ticks)
    plot_ax = sn.scatterplot(data=march_22, x="datetime", y=which, color=color,
                             ax=ax, s=50, edgecolor="none", legend=False)
    plot_ax.set_ylabel(name)
    plot_ax.yaxis.label.set_color(color)
    plot_ax.yaxis.set_major_locator(locator)
    return plot_ax

# Panels: temperature (+H2O twin), CH4, TGS2611-E00 (+TGS2600 twin).
plot_ts("temp_c", "Temperature\n(°C)", "red", axes[0], 4)
#plot_ts("rh", "RH (%)", "blue", axes[0].twinx(), 4)
plot_ts("CH4", "$\\mathrm{CH_4}$ (ppm)", "black", axes[1], 4)
plot_ts("H2O", "$\\mathrm{H_2O}$ (%)", "blue", axes[0].twinx(), 4)
plot_ts("tgs2600", "TGS2600 (kΩ)", "red", axes[2].twinx(), 4)
plot_ts("tgs2611", "TGS2611-E00\n(kΩ)", "black", axes[2], 4)

# The window covers about two days. A first-of-month major locator (as used
# for the full time series) places no ticks inside it, which leaves the x
# axis without date labels. Instead, label each day and add unlabelled minor
# ticks every 3 hours.
axes[2].xaxis.set_minor_locator(mdates.HourLocator(byhour = range(0, 24, 3)))
axes[2].xaxis.set_major_locator(mdates.DayLocator())
axes[2].xaxis.set_major_formatter(mdates.DateFormatter("%-d %b"))
axes[2].set_xlabel("")
plt.show()

#%%
#
# time series for march 17, 19, 22
#

# outlier ranges:
# march 17 16:00 to 18:10
# march 19 11:00 to 12:20
# march 22 13:50 to 21:40

# Outlier event windows (start, end) in the inside data.
ol_ranges = [["march 17 2023 16:00", "march 17 2023 18:10"],
             ["march 19 2023 11:00", "march 19 2023 12:20"],
             ["march 22 2023 13:50", "march 22 2023 21:40"]]

# Each window is padded on both sides by `mult` times its own length.
mult = 1.5

dt_ranges = []
for start_str, end_str in ol_ranges:
    start, end = pd.to_datetime(start_str), pd.to_datetime(end_str)
    pad = (end - start) * mult
    dt_ranges.append([start - pad, end + pad])

# String bounds ("mm/dd/yy HH:MM") for Series.between().
ranges = [[lo.strftime("%D %H:%M"), hi.strftime("%D %H:%M")] for lo, hi in dt_ranges]

# Column width ratios: each padded window relative to the longest one.
durations = [hi - lo for lo, hi in dt_ranges]
longest = max(durations)
ratios = [span / longest for span in durations]

def plot_ts(which, name, color, ax, ticks, data, ylabel = True):
    """Scatter ``data[which]`` against ``data["datetime"]`` on ``ax``.

    Parameters
    ----------
    which : str
        Column of ``data`` plotted on the y axis.
    name : str
        Text for the y-axis label (used only when ``ylabel`` is true).
    color : str
        Marker colour; also the label colour when ``ylabel`` is true.
    ax : matplotlib Axes
        Axes to draw on.
    ticks : int
        Requested number of major y ticks (the lowest tick is pruned).
    data : pandas.DataFrame
        Frame holding a ``"datetime"`` column and the ``which`` column.
    ylabel : bool, default True
        If true, draw the coloured y label. Otherwise remove the y label
        and blank the y tick labels.

    Returns
    -------
    The Axes returned by ``seaborn.scatterplot``.
    """
    plot_ax = sn.scatterplot(data = data, x = "datetime", y = which,
                             color = color, ax = ax, s = 50,
                             edgecolor = "none", legend = False)
    if not ylabel:
        plot_ax.set_ylabel(None)
        plot_ax.set_yticklabels([])
    else:
        plot_ax.yaxis.label.set_color(color)
        plot_ax.set_ylabel(name)
    plot_ax.yaxis.set_major_locator(
        MaxNLocator(prune = "lower", nbins = ticks, min_n_ticks = ticks))
    return plot_ax

# One column per outlier window, with three rows of panels:
#   row 0: temperature (left axis) and H2O (right twin axis)
#   row 1: reference CH4
#   row 2: TGS2611 (left axis) and TGS2600 (right twin axis)
# The number of columns follows ol_ranges instead of being hard-coded.
n_windows = len(ol_ranges)

# squeeze = False keeps `axes` two-dimensional even for a single window.
f, axes = plt.subplots(3, n_windows, sharex = "col", sharey = "row",
                       gridspec_kw = {'width_ratios': ratios}, figsize = (30, 15),
                       squeeze = False)
f.subplots_adjust(wspace = 0.05)

h2os = [axes[0, i].twinx() for i in range(n_windows)]
tgs2600s = [axes[2, i].twinx() for i in range(n_windows)]

for i, (window_start, window_end) in enumerate(ranges):
    df = inside[inside["datetime"].between(window_start, window_end)]

    # Right-hand (twin axis) labels only on the last column; left-hand labels
    # are kept on every column.
    right_ylabel = (i == n_windows - 1)
    left_ylabel = True

    plot_ts("temp_c", "Temperature\n(°C)", "red", axes[0, i], 4, df, left_ylabel)
    #plot_ts("rh", "RH (%)", "blue", axes[0].twinx(), 4)
    plot_ts("CH4", "$\\mathrm{CH_4}$ (ppm)", "black", axes[1, i], 4, df, left_ylabel)
    plot_ts("H2O", "$\\mathrm{H_2O}$ (%)", "blue", h2os[i], 4, df, right_ylabel)
    plot_ts("tgs2600", "TGS2600 (kΩ)", "green", tgs2600s[i], 4, df, right_ylabel)
    plot_ts("tgs2611", "TGS2611-E00\n(kΩ)", "black", axes[2, i], 4, df, left_ylabel)

    axes[2, i].xaxis.set_minor_locator(mdates.HourLocator())
    axes[2, i].xaxis.set_major_locator(mdates.HourLocator(interval = 3))
    axes[2, i].xaxis.set_major_formatter(mdates.DateFormatter("%H:%M"))
    axes[2, i].set_xlabel("")

    axes[0, i].set_title(dt_ranges[i][0].strftime("%Y-%m-%d"))

# The twin axes are not shared between columns, so give each family of twin
# axes one common y range spanning all columns.
for twins in (h2os, tgs2600s):
    ylim = [min(a.get_ylim()[0] for a in twins),
            max(a.get_ylim()[1] for a in twins)]
    for a in twins:
        a.set_ylim(ylim)

# Shade the un-padded outlier window in every row of its column.
for i, (ol_start, ol_end) in enumerate(ol_ranges):
    for row_ax in axes[:, i]:
        row_ax.axvspan(mdates.date2num(pd.to_datetime(ol_start)),
                       mdates.date2num(pd.to_datetime(ol_end)),
                       color = "blue", alpha = 0.1)

plt.show()

#%%
#
# Predicted vs. actual CH4: Eq. 1 and Eq. 2, each with a full and a piecewise fit
#

# Two calibration models for the TGS2611 resistance; Eq. 2 additionally uses
# the TGS2600 reading.
formulae = [["np.log(tgs2611) ~ temp_c + np.log(elapsed_time + 0.0001) + np.log(H2O)",
             "Eq. 1"],
            ["np.log(tgs2611) ~ temp_c + np.log(elapsed_time + 0.0001) + np.log(H2O) + np.log(tgs2600)",
             "Eq. 2"]]

target = "tgs2611"

# (label used in the facet titles / printout, time_fit flag for cal_route)
fit_modes = [("full", False), ("piecewise", True)]

frames = []
for formula, name in formulae:
    for mode, time_fit in fit_modes:
        (df, outliers, r2, rmse, r2_no_outlier,
         rmse_no_outlier, bgr2, bgrmse) = cal_route(formula, target,
                                                   time_fit = time_fit, plot = False)

        # .assign builds a new frame, so labelling does not write into a
        # column-subset selection (which raises SettingWithCopyWarning).
        df = df[["CH4", "predicted CH4"]].assign(which = name + ", " + mode + " fit")
        frames.append(df)
        print(name + ", " + mode + ": ", r2, rmse)

# Concatenate once instead of growing an initially empty DataFrame, which
# pandas >= 2.1 warns about.
results = pd.concat(frames)

results["which"] = results["which"].replace({"Eq. 1, full fit": "A: Eq. 1, full fit",
                                             "Eq. 1, piecewise fit": "B: Eq. 1, piecewise fit",
                                             "Eq. 2, full fit": "C: Eq. 2, full fit",
                                             "Eq. 2, piecewise fit": "D: Eq. 2, piecewise fit"})

# 2x2 grid of predicted vs. actual CH4, one facet per model / fit mode.
g = sn.FacetGrid(data = results, col = "which", col_wrap = 2,
                 height = 8, aspect = 1.1, sharex = True, sharey = True,
                 despine = False)
g.set_titles(col_template = "{col_name}", y = 0.9)
g.map(sn.scatterplot, "CH4", "predicted CH4",
      color = "black", edgecolor = "none", s = 10)

def one_to_one(*args, **kwargs):
    """Draw a red y = x reference line from 2 to 10 ppm on the current axes.

    Intended for ``FacetGrid.map``. Any positional or keyword arguments it
    passes are accepted and ignored.
    """
    bounds = (2, 10)
    plt.plot(bounds, bounds, color = "red")

# Overlay the 1:1 line on every facet and label the shared axes
# (map / set_xlabels / set_ylabels all return the grid, so they chain).
(g.map(one_to_one)
  .set_xlabels("Actual $CH_4$ (ppm)")
  .set_ylabels("Predicted $CH_4$ (ppm)"))

# Remove the gaps between facets.
plt.subplots_adjust(hspace=0, wspace=0)

plt.show()

#%%
#
# TGS2600 instead of 2611 - does it catch methane?
#

# Use the TGS2600 resistance as the target. Run it once with and once
# without the TGS2611 reading in the background model, both with a piecewise
# (time_fit = True) background.
target = "tgs2600"
tgs2600_runs = [
    ("np.log(tgs2600) ~ temp_c + np.log(elapsed_time + 0.0001) + np.log(H2O) + np.log(tgs2611)",
     "with 2611 in bg, piecewise bg: "),
    ("np.log(tgs2600) ~ temp_c + np.log(elapsed_time + 0.0001) + np.log(H2O)",
     "without 2611 in bg, piecewise bg: "),
]

for bg_formula, run_label in tgs2600_runs:
    (df, outliers, r2, rmse, r2_no_outlier,
     rmse_no_outlier, bgr2, bgrmse) = cal_route(bg_formula, target,
                                               time_fit = True, plot = True)
    print(run_label, bgr2, bgrmse, r2, rmse)

#%%
#
# outdoor data: CH4 calibration fit and diagnostic plots
#

# Background model for the TGS2611 resistance (the Eq. 2 form, which includes
# the TGS2600 reading).
bg_formula = "np.log(tgs2611) ~ temp_c + np.log(elapsed_time + 0.0001) + np.log(H2O) + np.log(tgs2600)"
target = "tgs2611"

df = outside.copy()
# CH4 change and elapsed time relative to the previous / next sample.
df["delta CH4 (previous)"] = df["CH4"].diff()
df["delta CH4 (next)"] = df["CH4"].diff(-1)
df["time gap (previous)"] = df["datetime"].diff()
df["time gap (next)"] = df["datetime"].shift(-1) - df["datetime"]

df, bgr2, bgrmse, _ = time_bg_fit(10, bg_formula, df, target = target, plot = False)

# Fit ratio = b0 + b1 * CH4 on the frame returned by time_bg_fit, then invert
# the line to predict CH4 = ratio / b1 - b0 / b1.
mod = smf.ols(formula = "ratio ~ CH4", data = df).fit()
ratio_coeff = 1/mod.params["CH4"]
intercept = (-1 * mod.params["Intercept"])/mod.params["CH4"]

df["predicted CH4"] = df["ratio"] * ratio_coeff + intercept

# RMSE is computed as sqrt(MSE). The `squared = False` keyword of
# mean_squared_error was deprecated in scikit-learn 1.4 and removed in 1.6.
r2 = r2_score(df["CH4"], df["predicted CH4"])
rmse = np.sqrt(mean_squared_error(df["CH4"], df["predicted CH4"]))

# Outliers: absolute prediction error above the 99th percentile.
df.loc[:,"error"] = df["predicted CH4"] - df["CH4"]
df.loc[:,"abs error"] = df["error"].abs()
df.loc[:,"outlier"] = df["abs error"] > df["abs error"].quantile(0.99)

no_outliers = df[~df["outlier"]]
outliers = df[df["outlier"]]

outliers.to_csv("outliers.csv")

r2_no_outlier = r2_score(no_outliers["CH4"], no_outliers["predicted CH4"])
rmse_no_outlier = np.sqrt(mean_squared_error(no_outliers["CH4"],
                                             no_outliers["predicted CH4"]))

# Diagnostic labels used to colour panels C-E below.
# Largest absolute CH4 step to either neighbouring sample.
df["CH4 change"] = df[["delta CH4 (previous)", "delta CH4 (next)"]].abs().max(axis = 1)
# "Gap": sample taken more than 11 minutes after the previous one.
df.loc[df["time gap (previous)"] > pd.Timedelta(11, "minutes"), "Condition"] = "Gap"
# The same three March windows that are shaded in the indoor time series.
df.loc[df["datetime"].between("march 17 2023 16:00", "march 17 2023 18:10"), "March weirdness"] = "March 17th"
df.loc[df["datetime"].between("march 19 2023 11:00", "march 19 2023 12:20"), "March weirdness"] = "March 19th"
df.loc[df["datetime"].between("march 22 2023 13:50", "march 22 2023 21:40"), "March weirdness"] = "March 22nd"
df.loc[df["Condition"].isna(), "Condition"] = "Normal"

# Layout: two panels on the top row (centred), three on the bottom row,
# all sharing the x and y axes.
fig = plt.figure(figsize = (20, 13))
gs = gridspec.GridSpec(2, 6)

ax1 = plt.subplot(gs[0, 1:3])
ax2 = plt.subplot(gs[0, 3:5], sharex = ax1, sharey = ax1)

ax3 = plt.subplot(gs[1, 0:2], sharex = ax1, sharey = ax1)
ax4 = plt.subplot(gs[1, 2:4], sharex = ax1, sharey = ax1)
ax5 = plt.subplot(gs[1, 4:6], sharex = ax1, sharey = ax1)

gs.tight_layout(fig, h_pad = 1.5, w_pad = 0)

# A: all points.
l = sn.scatterplot(data = df, x = "CH4", y = "predicted CH4", s = 20,
               edgecolor = "none", color = "black", legend = False, ax = ax1)

l.set_ylabel("Predicted $CH_4$ (ppm)")
l.set_title("A: $CH_4$ regression fit")

# B: outliers in orange.
r = sn.scatterplot(data = df, x = "CH4", y = "predicted CH4", s = 20,
               edgecolor = "none", hue = "outlier", palette = ["black", "orange"],
               legend = False, ax = ax2)
r.set_title("B: Regression outliers")

# C: coloured by CH4 change, with a matching colour bar. The same colormap
# drives both the scatter and the colour bar.
cmap = sn.dark_palette("hotpink", as_cmap = True)
sn.scatterplot(data = df, x = "CH4", y = "predicted CH4", s = 20,
               edgecolor = "none", hue = "CH4 change",
               palette = cmap,
               legend = False, ax = ax3)
ax3.set_title("C: $CH_4$ change (ppm per 10 minutes)")
norm = plt.Normalize(df["CH4 change"].min(), df["CH4 change"].max())
sm = plt.cm.ScalarMappable(cmap = cmap, norm = norm)
sm.set_array([])

cb = ax3.figure.colorbar(sm, ax = ax3, location = "right", anchor = (1, 0.5),
                         fraction = 0.05, pad = 0)
cb.ax.yaxis.set_ticks_position("right")

# D: samples following a data gap in red.
sn.scatterplot(data = df[df["Condition"] != "Gap"], x = "CH4", y = "predicted CH4", s = 20,
               edgecolor = "none", color = "black", legend = False, ax = ax4)
sn.scatterplot(data = df[df["Condition"] == "Gap"], x = "CH4", y = "predicted CH4", s = 20,
               edgecolor = "none", color = "red",
               legend = False, ax = ax4)
ax4.set_title("D: After data gap")

# E: the March windows coloured by date.
sn.scatterplot(data = df[df["March weirdness"].isna()], x = "CH4", y = "predicted CH4", s = 20,
               edgecolor = "none", color = "black", legend = False, ax = ax5)
m = sn.scatterplot(data = df, x = "CH4", y = "predicted CH4", s = 20,
                   edgecolor = "none", hue = "March weirdness", legend = True, ax = ax5)
m.legend(title = "", frameon = False, markerscale = 2)
ax5.set_title("E: March events")

# 1:1 reference line over the observed CH4 range on every panel.
for a in [ax1, ax2, ax3, ax4, ax5]:
    a.plot([df["CH4"].min(), df["CH4"].max()],
           [df["CH4"].min(), df["CH4"].max()],
           color='red',  linewidth=1.0)

    a.set_xlabel(None)

# Only the left-most panel of each row keeps its y label and tick labels.
for a in [ax2, ax4, ax5]:
    a.set_ylabel(None)
    plt.setp(a.get_yticklabels(), visible=False)

ax3.set_ylabel("Predicted $CH_4$ (ppm)")
ax4.set_xlabel("Actual $CH_4$ (ppm)")

plt.show()