# This script provides *some* of the outputs used in the revisions for JAE.
# The others are in partial_dependence.py, use_individual_functions.py, or white_reality_test.py
import configparser
import os
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import json
import seaborn as sns
import src.generalutils as gutil

# Settings
# Reset matplotlib's rcParams to the library defaults before the project
# style is applied (original note: keeps plots from rendering black in VS Code).
mpl.rcParams.update(mpl.rcParamsDefault)  # VS Code plots not black
config = configparser.ConfigParser()
# ConfigParser lower-cases option names by default; using ``str`` as the
# transform keeps the option names exactly as written in config.ini.
config.optionxform = str
# Path is relative to the current working directory.
config.read("config.ini")
# The [viz] section of config.ini is handed to matplotlib as style settings.
plt.style.use(config["viz"])
# Line styles looked up by series position in the plotting functions of this
# file. Tuple entries use matplotlib's (offset, (on, off, ...)) dash notation.
linestyles = [
    ":",
    "-.",
    "--",
    "-",
    (0, (3, 1, 1, 1, 1, 1)),
    (0, (5, 10)),
    (0, (3, 10, 1, 10)),
    (0, (3, 10, 1, 10, 1, 10)),
    (0, (3, 5, 1, 5)),
]

# Referee comments 1.1 and 1.3
# 1.1: ``I find the result that the text based uncertainty measures does not
# increase much during the Global Financial Crisis a bit strange. [...] Will
# the result [of Fig 3] change if you also include words such as
# \emph{risk} and \emph{unpredictable}?''
# 1.3: ``I would like to see how robust your sentiment and uncertainty measures
# (Fig. 1 and 3) are to different word dictionaries.''
#
# We address this by making use of our existing term-frequency data
# to plot some alternatives to our preferred measures of uncertainty, and of
# risk. Let's begin with uncertainty.


def address_r_one_one():
    """Plot one-word alternatives against the GRDN uncertainty measure.

    Addresses referee comment 1.1. Term-frequency series (requested with
    ``freq="M"``) for the words "risk", "unpredictable" and "precarious" --
    whichever of them exist in the term-frequency data -- are standardised
    to z-scores, smoothed with a 6-period rolling mean, and drawn next to the
    ``txtmetrics_u_mean`` column of ``ALL_swathes.csv``.

    Side effects: saves ``rev_fig_ref_1-1.eps`` to the configured output
    directory and then shows the figure. Returns nothing.
    """
    paper = "GRDN"
    uncert_words_wanted = ["risk", "unpredictable", "precarious"]

    df = gutil.getTFTimeSeries([], paper, freq="M")

    # Keep only the requested words that are present in the data
    u_words_available = [x for x in uncert_words_wanted if x in df.columns]
    df = df[u_words_available]

    # Now recreate the original plot but add the extra single-word indices
    inPath = os.path.join(config["data"]["results"], "ALL_swathes.csv")
    xf = pd.read_csv(
        inPath, parse_dates=True, index_col=0, infer_datetime_format=True, dayfirst=True
    )
    txtmetricName = "txtmetrics_u"
    xf = xf.loc[xf["paper"] == paper, :].drop("paper", axis=1)
    # Restrict to the uncertainty ("u") columns and drop dates where any of
    # them is missing, so the composite covers the same window as the swathe.
    colsToUse = [x for x in xf.columns if x.split("_")[1] == "u"]
    xf = xf[colsToUse]
    xf = xf.dropna(how="any")
    xf = xf[txtmetricName + "_mean"]

    _, ax = plt.subplots()
    ax.plot(xf, label=f"Uncertainty measure ({paper})")
    # z-score each word so it is comparable with the composite, then smooth
    tf = (df - df.mean(axis=0)) / df.std(axis=0)
    tf = tf.rolling(6).mean()
    tf = tf.loc[tf.index > xf.index.min()]
    for i, word in enumerate(tf.columns):
        # Wrap around so a longer word list cannot run off the end of the
        # line-style table.
        ax.plot(tf[word], label=word, alpha=0.7, ls=linestyles[i % len(linestyles)])
    ax.legend()
    ax.set_ylabel("Standard deviations from the mean")
    ax.set_title(
        "One-word alternatives to preferred measure of uncertainty", loc="left"
    )
    plt.tight_layout()
    outPath = os.path.join(config["data"]["output"], "rev_fig_ref_1-1.eps")
    plt.savefig(outPath, dpi=300)
    plt.show()


# Referee comment 1.3
# 1.3: I would like to see how robust your sentiment and uncertainty measures
# (Fig. 1 and 3) are to different word dictionaries.
#
# We address this by breaking down our sentiment and uncertainty metrics and
# plotting them out as individual lines.
def address_r_one_three():
    """Plot each composite text measure against its contributing series.

    Addresses referee comment 1.3. For DAIM (sentiment, suffix ``s``) and
    GRDN (uncertainty, suffix ``u``) the metrics listed in the
    ``txtmetrics_<suffix>`` config section are fetched, smoothed with a
    6-period rolling mean and drawn as thin lines. The
    ``txtmetrics_<suffix>_mean`` column of ``ALL_swathes.csv`` is overlaid as
    a thick black line.

    Metrics named in the ``negate_<suffix>`` list of the ``[negate]`` config
    section are sign-flipped before plotting.
    NOTE(review): this assumes the composite in the CSV was built from the
    flipped series, so that the thin lines are directly comparable with it --
    confirm against the code that produced ALL_swathes.csv.

    Side effects: prints a line per negated metric, saves one figure per
    paper as ``ref_1_3_<suffix>.eps`` in the configured output directory and
    shows it. Returns nothing.
    """
    papers = {"DAIM": "s", "GRDN": "u"}
    names_plots = {"s": "sentiment", "u": "uncertainty"}
    freq = "M"
    monthSmooth = 6
    norm = True  # Compare across series
    roll = False  # Actually true but do it w/in this function

    # The swathe file is the same for every paper, so read it once.
    inPath = os.path.join(config["data"]["results"], "ALL_swathes.csv")
    swathes = pd.read_csv(
        inPath,
        parse_dates=True,
        index_col=0,
        infer_datetime_format=True,
        dayfirst=True,
    )

    for paper, end in papers.items():
        # Iterate over sentiment, uncertainty
        metrics_list = list(dict(config["txtmetrics_" + end]))
        xf = gutil.getTimeSeries(
            metrics_list, [], paper, norm=norm, roll=roll, freq=freq
        )
        xf = xf.rolling(monthSmooth).mean()

        # Flip the sign on the frame that is actually plotted; a flipped
        # side copy would never reach the figure.
        negate_list = json.loads(config.get("negate", "negate_" + end))
        for col in metrics_list:
            if col in negate_list:
                print("negating " + col + " for " + end)
                xf[col] = -xf[col]

        # Now pull in relevant swathe for this group of series
        df = swathes.loc[swathes["paper"] == paper, :].drop("paper", axis=1)
        colsToUse = [x for x in df.columns if x.split("_")[1] == end]
        df = df[colsToUse]
        df = df.dropna(how="any")
        # Show the individual series only over the window the swathe covers
        xf = xf[xf.index > df.index.min()]

        _, ax = plt.subplots()
        for i, measure in enumerate(xf.columns):
            # Wrap around so a config section with more metrics than
            # line styles cannot raise IndexError.
            ax.plot(
                xf[measure],
                label=measure,
                ls=linestyles[i % len(linestyles)],
                lw=0.8,
            )
        ax.plot(
            df["txtmetrics_" + end + "_mean"],
            lw=3,
            label="Composite measure",
            color="k",
        )
        ax.legend()
        ax.set_ylabel("Standard deviations from the mean")
        ax.set_title(
            "Composite measure versus contributing series: " + names_plots[end],
            loc="left",
        )
        plt.tight_layout()
        outPath = os.path.join(config["data"]["output"], "ref_1_3_" + end + ".eps")
        plt.savefig(outPath, dpi=300)
        plt.show()


##########################
# Hyperparameter tuning/CV
##########################
# These results were generated by Eleni using the use_individual_functions.py file on
# github and changing the settings to CV = True. It took a long time to run.
# The results are in data/results/GDP_testCV_Predtf_vs_fctrs_AR1OLS.csv
# NB: though the exact code to reproduce the csv isn't included in main.py, it can be done
# with use_individual_functions.py. These outputs are not in the main body of
# the paper. The code below adds an RMSE summary. Let's make the chart of this:
def _rmse_by_spec(frame, keys):
    """Return one row per observed combination of ``keys`` plus an ``rmse`` column.

    The RMSE is computed between the ``OOS_prediction`` and ``target_value``
    columns of ``frame`` within each combination.
    """
    sq_err = np.square(frame["OOS_prediction"] - frame["target_value"])
    mean_sq_err = frame.assign(_sq_err=sq_err).groupby(keys)["_sq_err"].mean()
    return np.sqrt(mean_sq_err).rename("rmse").reset_index()


def _rmse_ratio_by_spec(sum_df, spec_cols):
    """Return one row per specification with the metric/benchmark RMSE ratio.

    ``sum_df`` must hold one row per (specification, ``run_type``) with an
    ``rmse`` column. Raises IndexError if any specification lacks either its
    "metric" or its "benchmark" run.
    """
    is_metric = sum_df["run_type"] == "metric"
    is_bench = sum_df["run_type"] == "benchmark"
    n_specs = len(sum_df[spec_cols].drop_duplicates())
    # Each (specification, run_type) pair occurs once, so equal counts mean
    # every specification has both runs.
    if is_metric.sum() != n_specs or is_bench.sum() != n_specs:
        raise IndexError(
            "every specification needs both a 'metric' and a 'benchmark' run"
        )
    wanted = spec_cols + ["rmse"]
    paired = sum_df.loc[is_metric, wanted].merge(
        sum_df.loc[is_bench, wanted], on=spec_cols, suffixes=("_metric", "_bch")
    )
    paired["RMSE/RMSE_bch"] = paired["rmse_metric"] / paired["rmse_bch"]
    return paired.drop(columns=["rmse_metric", "rmse_bch"])


def ML_bar_chart_facet_hyperparam(saveAnalysis=True):
    """
    Produces facet grid of target, model, and RMSE ratio.
    Shows std over not shown variables (e.g. horizons and papers)

    Reads ``data/results/GDP_testCV_Predtf_vs_fctrs_AR1OLS.csv``, keeps the
    rows that have an out-of-sample prediction, computes the RMSE for every
    specification and run type, and divides the "metric" RMSE by the
    "benchmark" RMSE of the same specification. The ratios are drawn as
    horizontal bars by model, one facet per target, with a dashed line at 1.

    Parameters
    ----------
    saveAnalysis : bool
        When true, save the figure as ``MLbarFacet_CV_GDP_DM_check.eps`` in
        the configured output directory.

    Raises
    ------
    IndexError
        If a specification is missing its "metric" or "benchmark" run.
    """
    # NOTE(review): the other functions in this file take the results
    # directory from config["data"]["results"]; this path is relative to the
    # working directory. Confirm both point at the same place.
    xf = pd.read_csv(
        os.path.join("data", "results", "GDP_testCV_Predtf_vs_fctrs_AR1OLS.csv")
    )
    # Only work out of sample:
    xf = xf[~pd.isna(xf["OOS_prediction"])]
    # Columns that together identify one specification (run_type excluded)
    spec_cols = [
        "target",
        "metric",
        "horizon",
        "paper",
        "trafo",
        "alpha",
        "stepSize",
        "expanding",
        "CV",
        "model",
    ]
    # RMSE per specification and run type, then metric relative to benchmark
    sum_df = _rmse_by_spec(xf, spec_cols + ["run_type"])
    res_df = _rmse_ratio_by_spec(sum_df, spec_cols)

    res_df = res_df.sort_values("RMSE/RMSE_bch")
    # Switch to nice names
    colsToNicify = ["target", "metric", "paper"]
    nicifyDict = gutil.nameConvert()
    nicifyDict.update({"COMB": "Mean"})
    for col in colsToNicify:
        res_df[col] = res_df[col].map(nicifyDict)
    convertColNamesD = {x: x.capitalize() for x in colsToNicify + ["model"]}
    res_df = res_df.rename(columns=convertColNamesD)
    res_df["Model"] = res_df["Model"].astype("category")
    # D Bholat suggested consistent ordering - changed to alphabetical
    y_order = sorted(res_df["Model"].unique())
    wrap_num = 3
    # catplot creates its own figure, so none is opened here; an explicit
    # plt.figure() would only leave an empty extra window behind.
    plt.close("all")
    g = sns.catplot(
        x="RMSE/RMSE_bch",
        y="Model",
        col="Target",
        data=res_df,
        kind="bar",
        # hue='Metric',
        height=4,
        legend=False,
        sharex=True,
        sharey=False,
        order=y_order,
        row_order=sorted(res_df["Target"].unique()),
        col_order=sorted(res_df["Target"].unique()),
        hue_order=y_order,
        facet_kws={"hue_order": y_order},
        ci="sd",
        col_wrap=wrap_num,
    )
    for i, axnow in enumerate(g.axes):
        # Reference line: ratio of 1 means no gain over the benchmark
        axnow.axvline(x=1.0, color="k", ls="--")
        axnow.set_xlabel(r"$\frac{\mathrm{RMSE}}{\mathrm{RMSE}_{\mathrm{Bench.}}}$")
        for item in [axnow.title, axnow.xaxis.label] + axnow.get_yticklabels():
            item.set_fontsize(15)
        axnow.set_xticks(np.arange(0.2, 2, step=0.4))
        # Only the first facet of each row keeps its model labels
        if (i % wrap_num) != 0:
            labels = axnow.get_yticklabels()
            axnow.set_yticklabels([""] * len(labels))
    if saveAnalysis:
        outPath = os.path.join(
            config["data"]["output"], "MLbarFacet_CV_GDP_DM_check" + ".eps"
        )
        plt.savefig(outPath, dpi=300)
    plt.show()
