# Added in response to editor comment R3.16 during revisions. Intended for the main paper, not just the referee response.
# Applies the following test:
#
# Test of Superior Predictive Ability (SPA) of White and Hansen.
#
# The SPA is also known as the Reality Check or Bootstrap Data Snooper.

import src.fcastutils as futil
import src.generalutils as gutil
import configparser
import json
import importlib
import itertools
import os
import pandas as pd
from tqdm import tqdm
import numpy as np
from arch.bootstrap import SPA

# Re-import src.fcastutils so the latest version on disk is the one in use
# (presumably useful when re-running this file in an interactive session —
# confirm). futil is not otherwise referenced in the visible part of this file.
importlib.reload(futil)


# Load settings from config.ini, resolved relative to the working directory.
config = configparser.ConfigParser()
# Keep option names case-sensitive; configparser lower-cases them by default.
config.optionxform = str
config.read("config.ini")
# Absolute path of the results directory named in the [data] section.
output_dir = os.path.realpath(config["data"]["results"])
# Seed NumPy's global random number generator.
np.random.seed(23456)
# Common seed used throughout for SPA test
# NOTE(review): the SPA call in reality_test passes a literal seed=551 rather
# than this value — confirm which one is intended.
seed = 1382121707


def starFunc(numin):
    """Format a p-value to two decimals followed by significance stars.

    :param numin: p-value to format
    :type numin: float
    :return: the value rounded to two decimals followed by a three-character
        marker: "***" below 0.01, "** " below 0.05, "*  " below 0.1 and
        three spaces otherwise
    :rtype: str
    """
    # Cut-offs run from strictest to loosest and the first match wins. Each
    # marker is padded to three characters so the suffix has a fixed width.
    stars = "   "
    for cutoff, marker in ((0.01, "***"), (0.05, "** "), (0.1, "*  ")):
        if numin < cutoff:
            stars = marker
            break
    return "{:.2f}".format(round(numin, 2)) + stars


def msefunc(seriesOne, seriesTwo):
    """Return the element-wise squared difference between two inputs.

    Despite the name, no averaging is performed: the result has one squared
    error per element rather than a single mean value.

    :param seriesOne: first operand (e.g. realised values)
    :param seriesTwo: second operand (e.g. predictions)
    :return: ``(seriesOne - seriesTwo) ** 2``
    """
    deviation = seriesOne - seriesTwo
    return deviation ** 2


def msefunctext(seriesOne, seriesTwo):
    """Return squared errors of every model column against the target.

    :param seriesOne: realised values; joined onto ``seriesTwo`` by index and
        then read back as the ``target_value`` column, so it must carry that
        name
    :type seriesOne: pandas.Series
    :param seriesTwo: predictions, one column per model
    :type seriesTwo: pandas.DataFrame
    :return: frame with the columns of ``seriesTwo`` holding
        ``(target_value - prediction) ** 2``
    :rtype: pandas.DataFrame
    """
    model_cols = list(seriesTwo.columns)
    joined = seriesTwo.join(seriesOne)
    # Turn the target into a column vector so it broadcasts across all model
    # columns positionally instead of being aligned by label.
    actual = joined["target_value"].values[:, None]
    deviation = actual - joined[model_cols]
    return deviation ** 2


def retrieve_ml_forecast(df, target, horizon):
    """Extract realised values and forecasts for one target and horizon.

    Rows containing any missing value are discarded after filtering.

    :param df: results frame with columns ``target``, ``horizon``,
        ``target_value``, ``run_type``, ``model`` and ``OOS_prediction``; its
        index must be named ``date`` because the pivot below groups on it
    :type df: pandas.DataFrame
    :param target: value to match in the ``target`` column
    :param horizon: value to match in the ``horizon`` column
    :return: tuple of (realised values, first row per index label; predictions
        of the ``metric`` runs pivoted to one column per model; prediction of
        the ``benchmark`` runs, first row per index label)
    :rtype: tuple
    """
    is_match = (df["target"] == target) & (df["horizon"] == horizon)
    subset = df[is_match].dropna()

    # Realised values repeat on every row of a date, so keep the first only.
    first_seen = ~subset.index.duplicated(keep="first")
    actuals = subset.loc[first_seen, "target_value"].sort_index()

    # Model forecasts: reshape from long form to one column per model.
    is_metric = subset["run_type"] == "metric"
    model_preds = (
        subset.loc[is_metric, ["OOS_prediction", "model"]]
        .dropna()
        .sort_index()
        .set_index(["model"], append=True)
        .pivot_table(values="OOS_prediction", index=["date"], columns="model")
    )

    # Benchmark forecast: a single series, again first row per index label.
    benchmark_rows = subset.loc[subset["run_type"] == "benchmark"].dropna()
    benchmark_first = ~benchmark_rows.index.duplicated(keep="first")
    benchmark_preds = benchmark_rows.loc[benchmark_first, "OOS_prediction"]
    benchmark_preds = benchmark_preds.sort_index()
    return actuals, model_preds, benchmark_preds


def reality_test(specification):
    """Runs the reality test on ML forecasts.

    For every combination of target and horizon, the squared forecast errors
    of the benchmark and of each model are passed to the SPA test. The three
    resulting p-values (lower, consistent, upper) are collected, one row per
    combination, and written to ``reality_test_<specification>.csv`` in the
    results directory named in the config.

    :param specification: One of tf_vs_AR1OLS or tf_vs_fctrOLS
    :type specification: str
    :raises KeyError: if ``specification`` is not one of the two names above
    :raises ValueError: if "model" or "metric" is missing from the
        ``runSettings`` list in the config
    """
    targetsL = list(gutil.allTargetsDict().keys())
    # Names of the run settings, used below as column labels for each entry.
    namesVars = json.loads(config.get("runSettings", "runSettings"))
    namesVars.remove("model")
    namesVars.remove("metric")
    horizonsL = json.loads(config.get("runSettings", "horizons"))
    alpha, stepSize, trafo, expanding, CV = 36, 1, "none", False, False
    # Only targets and horizons vary; every other setting is a single value.
    argsCombos = list(
        itertools.product(
            targetsL,
            horizonsL,
            ["COMB"],
            [trafo],
            [alpha],
            [stepSize],
            [expanding],
            [CV],
            [specification],
        )
    )
    spec_dict = {
        "tf_vs_AR1OLS": "ALL_tf_ML_resultsOLSAR1",
        "tf_vs_fctrOLS": "ALL_tf_ML_results_fctrOLS",
    }
    df = pd.read_pickle(
        os.path.join(config["data"]["results"], spec_dict[specification] + ".pkl")
    )
    totalRuns = len(argsCombos)
    resultRows = []
    # tqdm wraps the list itself (not the enumerate object) so that it knows
    # the total and can show a proper progress bar.
    for runNo, entry in enumerate(tqdm(argsCombos), start=1):
        print(" \n -------------------------- \n")
        print("Iteration {} of {}".format(runNo, totalRuns))
        print("Running " + specification + " for " + str(entry))
        target, horizon = entry[0], entry[1]
        ycol, txtpred, bchpred = retrieve_ml_forecast(df, target, horizon)
        errtext = msefunctext(ycol, txtpred)
        errben = msefunc(ycol, bchpred)
        # NOTE(review): the two loss arrays are passed positionally, so this
        # assumes benchmark and model forecasts cover the same dates — confirm.
        # NOTE(review): the seed is the literal 551, not the module-level
        # `seed`; kept as is so that existing results stay reproducible.
        spa = SPA(benchmark=errben.values, models=errtext.values, seed=551, reps=5000)
        spa.compute()
        lower, consistent, upper = spa.pvalues
        outDf = pd.DataFrame(
            data=[[lower, consistent, upper]],
            columns=["lower", "consistent", "upper"],
        )
        # Label the row with the settings it was produced for; zip stops at
        # the shorter of the two sequences.
        for col, value in zip(namesVars, entry):
            outDf[col] = value
        resultRows.append(outDf)
    # Concatenate once at the end instead of growing a frame inside the loop.
    # Row labels are left as they are (all 0) to keep the CSV layout unchanged.
    summariseddf = pd.concat(resultRows, axis=0) if resultRows else pd.DataFrame()
    # Save output
    outPathcsv = os.path.join(
        config["data"]["results"], "reality_test_" + specification + ".csv"
    )
    summariseddf.to_csv(outPathcsv)


def write_out_reality_table(specification):
    """Outputs the results of a reality test to a table.

    Reads ``reality_test_<specification>.csv`` from the results directory,
    reshapes it to one row per (horizon, p-value type) and one column per
    target, renames the target columns via ``gutil.nameConvert()`` and writes
    the table as LaTeX to the output directory.

    :param specification: One of tf_vs_AR1OLS or tf_vs_fctrOLS
    :type specification: str
    """
    csvPath = os.path.join(
        config["data"]["results"], "reality_test_" + specification + ".csv"
    )
    pvalues = pd.read_csv(csvPath, index_col=0)

    # Long form: one row per (horizon, target, p-value type).
    longForm = pd.melt(
        pvalues,
        id_vars=["horizon", "target"],
        value_vars=["lower", "consistent", "upper"],
    )
    longForm = longForm.rename(columns={"variable": "type"})

    # Wide form: targets across the columns.
    table = longForm.pivot(
        index=["horizon", "type"], columns=["target"], values=["value"]
    )
    # Passing values as a list adds an outer column level; drop it so that
    # only the target names remain.
    table.columns = table.columns.droplevel()
    table = table.rename(columns=gutil.nameConvert())

    outPath = os.path.join(
        config["data"]["output"], "reality_table" + specification + ".txt"
    )
    table.to_latex(outPath)


if __name__ == "__main__":
    # Run the SPA test and write its table for each specification in turn:
    # first tf_vs_AR1OLS, then tf_vs_fctrOLS.
    for spec in ("tf_vs_AR1OLS", "tf_vs_fctrOLS"):
        reality_test(spec)
        write_out_reality_table(spec)
