#!/usr/bin/env python3
# -*- coding: utf-8 -*-
r"""
@author: at

Defines a series of ML models to run on term-frequency matrices.

Individual model run output format:
metric | target_name | target_val | \hat{y}_OOS | \hat{y}_IS | alpha | trafo |
        model_name | paper | stepSize | date | specification | CV | expanding

"""
import src.fcastutils as futil
import src.generalutils as gutil
import configparser
import json
import importlib
import itertools
import os
import glob
import pandas as pd
from tqdm import tqdm
import numpy as np
from pathlib import Path
from functools import reduce
import warnings
from loguru import logger

# Re-execute the project helper modules at import time (presumably so that
# edits made to them during an interactive session are picked up without
# restarting the interpreter).
importlib.reload(futil)
importlib.reload(gutil)


# Module-wide settings, read once at import time. "config.ini" is a relative
# path, so it is resolved against the current working directory.
config = configparser.ConfigParser()
# Keep option names case-sensitive (ConfigParser lower-cases them by default).
config.optionxform = str
config.read("config.ini")
# NOTE: this silences every warning for the whole process, not only the ones
# raised by this module.
warnings.filterwarnings('ignore')

def predict_tf_vsAR1(
    yvar, horizon, paper, trafo, alpha, stepSize, expanding, CV, model
):
    r"""
    Forecast ``yvar`` ``horizon`` periods ahead from term frequencies plus
    one lag of ``yvar`` itself.

    Model specification:
        y_{t+h} = f(y_{t-1}, \vec{x}_t)
    (original author's note: here, alpha is a constant in the reg, not the
    initial period). The metric \vec{x}_t is every column of the frame
    returned by ``gutil.getTFTimeSeries(yvar, paper, freq="M")`` other than
    ``yvar``.

    All arguments are forwarded unchanged to ``futil.modelSpecWrap``,
    together with the fixed labels metric="tf_matrix" and
    specification="tf_vs_AR1".

    Returns whatever ``futil.modelSpecWrap`` returns. According to the
    original notes this is a date-indexed table with the columns
    IS_prediction, OOS_prediction, metric, target_name, paper, horizon,
    specification, model, trafo, CV, alpha, expanding, target_value and
    run_type (not verifiable from this file).
    """
    # Data retrieval
    frame = gutil.getTFTimeSeries(yvar, paper, freq="M")
    target_col = "target"
    # "target" may itself be one of the words tracked in the tf matrix;
    # move that word out of the way so it cannot clash with the y column.
    if target_col in frame.columns:
        frame = frame.rename(columns={target_col: target_col + "_word"})
    lag_col = yvar + "_lag_1"
    # Text metrics: everything present before the lag/target columns are added
    text_cols = [col for col in frame.columns if col != yvar]
    # AR(1) control: yvar lagged by one period
    frame[lag_col] = frame[yvar].shift(1)
    # Dependent variable: yvar `horizon` periods ahead
    frame[target_col] = frame[yvar].shift(-horizon)
    # Hand over to the shared forecasting wrapper
    return futil.modelSpecWrap(
        frame,
        text_cols,
        [lag_col],
        frame[yvar],
        yvar,
        "tf_matrix",
        horizon,
        paper,
        trafo,
        alpha,
        stepSize,
        expanding,
        CV,
        model,
        "tf_vs_AR1",
    )


def predict_tf_vsFactorModel(
    yvar, horizon, paper, trafo, alpha, stepSize, expanding, CV, model, factors=2
):
    r"""
    Forecast ``yvar`` ``horizon`` periods ahead from term frequencies, with
    principal-component factors and one lag of ``yvar`` as controls.

    Model specification:
        y_{t+h} = f(y_{t-1}, \vec{x}_t, \vec{F}_t)
    where \vec{F}_t are the columns PC_1 ... PC_<factors> and the metric
    \vec{x}_t is every remaining column of the retrieved frame other than
    ``yvar``.

    All run arguments are forwarded unchanged to ``futil.modelSpecWrap``,
    together with the fixed labels metric="tf_matrix" and
    specification="tf_vs_fctrs".

    Returns whatever ``futil.modelSpecWrap`` returns. According to the
    original notes this is a date-indexed table with the columns
    IS_prediction, OOS_prediction, metric, target_name, paper, horizon,
    specification, model, trafo, CV, alpha, expanding, target_value and
    run_type (not verifiable from this file).
    """
    # Original note: there are only ten factors.
    assert factors < 10
    pc_names = [f"PC_{k}" for k in range(1, factors + 1)]
    # Data retrieval: yvar together with the requested factor columns
    frame = gutil.getTFTimeSeries([yvar, *pc_names], paper, freq="M")
    target_col = "target"
    # "target" may itself be one of the words tracked in the tf matrix;
    # move that word out of the way so it cannot clash with the y column.
    if target_col in frame.columns:
        frame = frame.rename(columns={target_col: target_col + "_word"})
    lag_col = yvar + "_lag_1"
    control_cols = gutil.singleListCheck(pc_names) + [lag_col]
    # Text metrics: whatever is neither the dependent variable nor a control
    text_cols = [
        col for col in frame.columns if col != yvar and col not in control_cols
    ]
    # yvar lagged by one period
    frame[lag_col] = frame[yvar].shift(1)
    # Dependent variable: yvar `horizon` periods ahead
    frame[target_col] = frame[yvar].shift(-horizon)
    return futil.modelSpecWrap(
        frame,
        text_cols,
        control_cols,
        frame[yvar],
        yvar,
        "tf_matrix",
        horizon,
        paper,
        trafo,
        alpha,
        stepSize,
        expanding,
        CV,
        model,
        "tf_vs_fctrs",
    )


def predict_tf_vsAR1OLS(
    yvar, horizon, paper, trafo, alpha, stepSize, expanding, CV, model
):
    r"""
    ML on term frequencies versus an OLS AR(1) benchmark.

    Model specification:
        y_{t+h} = f(y_{t-1}, \vec{x}_t)
    against
        y_{t+h} = beta.y_{t-1} + epsilon
    The metric \vec{x}_t is every column of the retrieved frame other than
    ``yvar``.

    Two forecasts are run on the same date index: first ``model`` on the
    text metrics plus the lag (run_type "metric"), then "OLS" on the lag
    alone (run_type "benchmark"). Both result sets are labelled with
    ``model`` and specification "tf_vs_AR1OLS"; the fact that the benchmark
    uses OLS is reflected in the specification only.

    Returns the two collated result sets stacked row-wise with
    ``pd.concat``. According to the original notes each is a date-indexed
    table with the columns IS_prediction, OOS_prediction, metric,
    target_name, paper, horizon, specification, model, trafo, CV, alpha,
    expanding, target_value and run_type (not verifiable from this file).
    """
    # Data retrieval
    frame = gutil.getTFTimeSeries(yvar, paper, freq="M")
    target_col = "target"
    # "target" may itself be one of the words tracked in the tf matrix;
    # move that word out of the way so it cannot clash with the y column.
    if target_col in frame.columns:
        frame = frame.rename(columns={target_col: target_col + "_word"})
    lag_col = yvar + "_lag_1"
    text_cols = [col for col in frame.columns if col != yvar]
    ar_controls = [lag_col]
    # yvar lagged by one period, and yvar `horizon` periods ahead
    frame[lag_col] = frame[yvar].shift(1)
    frame[target_col] = frame[yvar].shift(-horizon)
    # Keep yvar aside before it is dropped from the modelling frame
    realised = frame[yvar]
    # Labels shared by both runs; only the trailing run_type differs
    shared_settings = (
        yvar,
        "tf_matrix",
        horizon,
        paper,
        trafo,
        alpha,
        stepSize,
        expanding,
        CV,
        model,
        "tf_vs_AR1OLS",
    )
    regressors = text_cols + ar_controls
    # As-if-real-time transformation of the text metrics
    frame[text_cols] = futil.txtMetricTransformer(
        frame[text_cols], alpha, stepSize, trafo
    )
    # yvar must not reach the forecasting step; only target does
    frame = frame.drop(columns=yvar)
    frame = futil.prepForModelling(frame, regressors + [target_col])
    # This datetime index is reused for both runs so the comparison is fair
    sample_index = frame.index
    # Align the held-back yvar with the modelling sample
    realised = pd.DataFrame(realised).loc[sample_index, :]

    # Run 1: the requested model on text metrics + AR(1) control
    ml_run = futil.forecastTool(
        frame.loc[sample_index, [target_col] + regressors],
        target_col,
        regressors,
        alpha,
        stepSize,
        expanding,
        CV,
        model,
    )
    ml_summary = futil.collateRunResults(
        ml_run, realised, sample_index, *shared_settings, "metric"
    )

    # Run 2: identical settings and index, but the estimator is forced to
    # OLS and only the AR(1) control is used.
    ols_run = futil.forecastTool(
        frame.loc[sample_index, [target_col] + ar_controls],
        target_col,
        ar_controls,
        alpha,
        stepSize,
        expanding,
        CV,
        "OLS",
    )  # Important line
    # The labels still carry the ML model name on purpose (see docstring)
    ols_summary = futil.collateRunResults(
        ols_run, realised, sample_index, *shared_settings, "benchmark"
    )
    # Stack both phases
    return pd.concat([ml_summary, ols_summary], axis=0)


def predict_tf_vsfctrOLS(
    yvar, horizon, paper, trafo, alpha, stepSize, expanding, CV, model, factors=2
):
    r"""
    ML on term frequencies versus an OLS factors + AR(1) benchmark.

    Model specification:
        y_{t+h} = f(y_{t-1}, \vec{F}_t, \vec{x}_t)
    against
        y_{t+h} = beta.y_{t-1} + \sum_j \gamma_j F_j + epsilon
    where the factors are the columns PC_1 ... PC_<factors> and the metric
    \vec{x}_t is every remaining column of the retrieved frame other than
    ``yvar``.

    Two forecasts are run on the same date index: first ``model`` on the
    text metrics plus the controls (run_type "metric"), then "OLS" on the
    controls alone (run_type "benchmark"). Both result sets are labelled
    with ``model`` and specification "tf_vs_fctrOLS"; the fact that the
    benchmark uses OLS is reflected in the specification only.

    Returns the two collated result sets stacked row-wise with
    ``pd.concat``. According to the original notes each is a date-indexed
    table with the columns IS_prediction, OOS_prediction, metric,
    target_name, paper, horizon, specification, model, trafo, CV, alpha,
    expanding, target_value and run_type (not verifiable from this file).
    """
    # Original note: there are only ten factors.
    assert factors < 10
    pc_names = [f"PC_{k}" for k in range(1, factors + 1)]
    # Data retrieval: yvar together with the requested factor columns
    frame = gutil.getTFTimeSeries([yvar, *pc_names], paper, freq="M")
    target_col = "target"
    # "target" may itself be one of the words tracked in the tf matrix;
    # move that word out of the way so it cannot clash with the y column.
    if target_col in frame.columns:
        frame = frame.rename(columns={target_col: target_col + "_word"})
    lag_col = yvar + "_lag_1"
    control_cols = gutil.singleListCheck(pc_names) + [lag_col]
    text_cols = [
        col for col in frame.columns if col != yvar and col not in control_cols
    ]
    # yvar lagged by one period, and yvar `horizon` periods ahead
    frame[lag_col] = frame[yvar].shift(1)
    frame[target_col] = frame[yvar].shift(-horizon)
    # Keep yvar aside before it is dropped from the modelling frame
    realised = frame[yvar]
    # Labels shared by both runs; only the trailing run_type differs
    shared_settings = (
        yvar,
        "tf_matrix",
        horizon,
        paper,
        trafo,
        alpha,
        stepSize,
        expanding,
        CV,
        model,
        "tf_vs_fctrOLS",
    )
    regressors = text_cols + control_cols
    # As-if-real-time transformation of the text metrics
    frame[text_cols] = futil.txtMetricTransformer(
        frame[text_cols], alpha, stepSize, trafo
    )
    # yvar must not reach the forecasting step; only target does
    frame = frame.drop(columns=yvar)
    frame = futil.prepForModelling(frame, regressors + [target_col])
    # This datetime index is reused for both runs so the comparison is fair
    sample_index = frame.index
    # Align the held-back yvar with the modelling sample
    realised = pd.DataFrame(realised).loc[sample_index, :]

    # Run 1: the requested model on text metrics + factor/AR(1) controls
    ml_run = futil.forecastTool(
        frame.loc[sample_index, [target_col] + regressors],
        target_col,
        regressors,
        alpha,
        stepSize,
        expanding,
        CV,
        model,
    )
    ml_summary = futil.collateRunResults(
        ml_run, realised, sample_index, *shared_settings, "metric"
    )

    # Run 2: identical settings and index, but the estimator is forced to
    # OLS and only the controls are used.
    ols_run = futil.forecastTool(
        frame.loc[sample_index, [target_col] + control_cols],
        target_col,
        control_cols,
        alpha,
        stepSize,
        expanding,
        CV,
        "OLS",
    )  # Important line
    # The labels still carry the ML model name on purpose (see docstring)
    ols_summary = futil.collateRunResults(
        ols_run, realised, sample_index, *shared_settings, "benchmark"
    )
    # Stack both phases
    return pd.concat([ml_summary, ols_summary], axis=0)


def looper_tf_AR1():
    """
    Used to loop over AR(1) specifications when NOT using batch jobs in the cloud.

    Runs predict_tf_vsAR1 for the cartesian product of a subset of the
    configured targets, horizons, papers and ML models (fixed settings:
    alpha=36, stepSize=1, trafo="none", expanding=False, CV=False), stacks
    the per-run results and pickles them to "ALL_tf_vs_AR1.pkl" in the
    results directory named by config["data"]["results"].

    A run that raises is logged with its traceback and skipped, so one bad
    specification does not abort the remaining runs.

    Returns the stacked results (an empty DataFrame if every run failed).
    """
    # Only slices of the configured lists are looped over here.
    papersL = json.loads(config.get("papers", "paper_list"))[:2]
    targetsL = list(gutil.allTargetsDict().keys())[:2]
    modelsL = json.loads(config.get("runSettings", "MLmodels"))[2:]
    horizonsL = json.loads(config.get("runSettings", "horizons"))[:2]
    alpha, stepSize, trafo, expanding, CV = 36, 1, "none", False, False
    # Argument order matches the signature of predict_tf_vsAR1
    arg_combinations = list(
        itertools.product(
            targetsL,
            horizonsL,
            papersL,
            [trafo],
            [alpha],
            [stepSize],
            [expanding],
            [CV],
            modelsL,
        )
    )
    n_runs = len(arg_combinations)
    all_results = pd.DataFrame()
    # enumerate() has no length, so tqdm needs the total to show progress
    for i, entry in tqdm(enumerate(arg_combinations), total=n_runs):
        print(" \n -------------------------- \n")
        print("Iteration {} of {}".format(i + 1, n_runs))
        print("Running met vs AR1 for " + str(entry))
        # The concat stays inside the try so that a malformed result from one
        # run is skipped rather than losing the results gathered so far.
        try:
            resHere = predict_tf_vsAR1(*entry)
            all_results = pd.concat([all_results, resHere], axis=0)
        except Exception:
            # Best effort: record the failure and carry on with the next run.
            logger.exception("Forecast failed for " + str(entry))
    # Save runs to file
    out_path = os.path.join(config["data"]["results"], "ALL_tf_vs_AR1.pkl")
    all_results.to_pickle(out_path)
    print("Finished!")
    return all_results


def generate_all_combinations_ML_fcasts():
    """
    Build every combination of settings for the ML forecasts.

    Returns a list of 10-tuples
    (spec code, target, horizon, paper, trafo, alpha, stepSize, expanding,
    CV, model), where the spec code is
    0=AR1, 1=factor model, 2 = OLSAR1, 3=OLS fctr.
    0 and 1 use ML as the benchmark
    2 and 3 use OLS as the benchmark

    Papers, models and horizons come from config; targets are the keys of
    gutil.allTargetsDict(). The remaining settings are fixed at
    trafo="none", alpha=36, stepSize=1, expanding=False, CV=False.
    """
    papers = json.loads(config.get("papers", "paper_list"))
    targets = list(gutil.allTargetsDict().keys())
    models = json.loads(config.get("runSettings", "MLmodels"))
    horizons = json.loads(config.get("runSettings", "horizons"))
    # Each fixed setting is wrapped in a one-element list so that it takes
    # part in the product without multiplying the number of combinations.
    axes = (
        [0, 1, 2, 3],  # specification codes, see docstring
        targets,
        horizons,
        papers,
        ["none"],  # trafo
        [36],  # alpha
        [1],  # stepSize
        [False],  # expanding
        [False],  # CV
        models,
    )
    return list(itertools.product(*axes))


def run_individual_ML_spec(arg_combination: tuple):
    """
    Performs a single ML forecast with the given tuple of settings

    arg_combination is a 10-tuple
    (spec code, yvar, horizon, paper, trafo, alpha, stepSize, expanding,
    CV, model), with spec code 0=AR1, 1=factor model, 2=OLSAR1, 3=OLS fctr
    selecting which predict_tf_* function is called.

    Returns the result of the selected predict_tf_* function.

    Raises:
        ValueError: if the spec code is not one of 0-3.
        RuntimeError: if the selected forecast function raises; the
            original exception is attached as the cause.
    """
    (
        spec_func_no,
        yvar,
        horizon,
        paper,
        trafo,
        alpha,
        stepSize,
        expanding,
        CV,
        model,
    ) = arg_combination
    # Dispatch table instead of an if/elif chain, so that an unknown code is
    # reported explicitly instead of leaving the function unbound.
    spec_funcs = {
        0: predict_tf_vsAR1,
        1: predict_tf_vsFactorModel,
        2: predict_tf_vsAR1OLS,
        3: predict_tf_vsfctrOLS,
    }
    if spec_func_no not in spec_funcs:
        raise ValueError(
            "Unknown specification code {!r}; expected one of {}".format(
                spec_func_no, sorted(spec_funcs)
            )
        )
    run_settings = (yvar, horizon, paper, trafo, alpha, stepSize, expanding, CV, model)
    try:
        return spec_funcs[spec_func_no](*run_settings)
    except Exception as err:
        # Previously the RuntimeError was built but never raised, so a failed
        # run surfaced as an UnboundLocalError on the return statement.
        raise RuntimeError(
            "Failed to run model for {}".format(arg_combination)
        ) from err


def run_all_ML_model_fcasts():
    """
    Runs all specifications for tf matrices (using ML). This will take a very long time to run as it, by default,
    will do the moving window analysis for every given ML forecast specification.

    Results accumulate in "<specFileStem>_results.pkl" in the results
    directory. The file is re-read and re-written after every successful
    run, and a combination that already has rows in the file is skipped.

    A run that raises is logged with its traceback and skipped, so one bad
    specification does not abort the remaining runs.
    """
    out_path = Path(os.path.join(config["data"]["results"], config["tf_vs_ML"]["specFileStem"] + "_results.pkl"))
    arg_combinations = generate_all_combinations_ML_fcasts()
    n_runs = len(arg_combinations)
    # Loop-invariant lookup tables used to build the selection criterion
    specification = {0: "tf_vs_AR1", 1: "tf_vs_fctrs", 2: "tf_vs_AR1OLS", 3: "tf_vs_fctrOLS"}
    # NOTE(review): these are assumed to be column names of the stored
    # results - confirm against what the predict_tf_* functions return.
    spec_keys = ["specification", "target", "horizon", "paper", "trafo", "alpha", "stepSize", "expanding", "CV", "model"]
    # enumerate() has no length, so tqdm needs the total to show progress
    for i, entry in tqdm(enumerate(arg_combinations), total=n_runs):
        logger.info("Iteration {} of {}".format(i + 1, n_runs))
        logger.info("Running " + str(entry))
        spec_dict = dict(zip(spec_keys, [specification[entry[0]]] + list(entry[1:])))
        # Check if file exists already; if not, create empty
        if not out_path.is_file():
            print("Results file not found; creating")
            pd.DataFrame().to_pickle(out_path)
            in_results_already = False
        else:  # check whether this specification already in file
            all_results = pd.read_pickle(out_path)
            if all_results.empty:
                # The file holds no rows yet (e.g. it was only just created
                # and the first run failed). An empty frame has none of the
                # columns, so indexing it below would raise KeyError.
                in_results_already = False
            else:
                bool_index_entries = reduce(
                    np.logical_and,
                    [all_results[x] == y for x, y in spec_dict.items()],
                )
                # Any matching row means this combination was already run
                in_results_already = bool(bool_index_entries.any())
        if in_results_already:
            matched = " & ".join([f"{x} == {y}" for x, y in spec_dict.items()])
            print(f"Existing results found for specification {matched}")
            continue
        try:
            results = run_individual_ML_spec(entry)
            # read existing results
            all_results = pd.read_pickle(out_path)
            # append new results
            all_results = pd.concat([all_results, results], axis=0)
            all_results.to_pickle(out_path)
        except Exception:
            # Best effort: record the failure and carry on with the next run.
            logger.exception("Forecast failed for " + str(entry))
    logger.info("Finished ML forecasts!")


def create_summary_ML_files():
    """
    Produces stat summaries of all runs using existing detailed data written to file.

    Reads "ALL_tf_ML_results.pkl" from the results directory and, for every
    combination from generate_all_combinations_ML_fcasts(), selects the
    matching rows, passes them to futil.runVSbchmarkData and labels the
    summary with the combination's settings. The stacked summaries are
    pickled to "ALL_SUM_tf_ML.pkl" in the same directory.

    A combination that cannot be summarised is logged with its traceback
    and skipped, so the remaining combinations are still processed.
    """
    # NOTE(review): the input name is hard-coded here, whereas the functions
    # that write results use config["tf_vs_ML"]["specFileStem"] - confirm the
    # two agree.
    df = pd.read_pickle(
        os.path.join(config["data"]["results"], "ALL_tf_ML_results.pkl")
    )
    specification_type = "tf_ML"
    arg_combinations = generate_all_combinations_ML_fcasts()
    n_runs = len(arg_combinations)
    # NOTE(review): astype(bool) maps any non-empty string (including
    # "False") to True - assumes the column is already boolean/numeric.
    df["expanding"] = df["expanding"].astype(bool)
    df["horizon"] = df["horizon"].astype(int)
    # Loop-invariant lookup tables used to build the selection criterion
    specification = {0: "tf_vs_AR1", 1: "tf_vs_fctrs", 2: "tf_vs_AR1OLS", 3: "tf_vs_fctrOLS"}
    spec_keys = ["specification", "target", "horizon", "paper", "trafo", "alpha", "stepSize", "expanding", "CV", "model", "metric"]
    all_summry_results = pd.DataFrame()
    for i, entry in enumerate(arg_combinations):
        logger.info("Iteration {} of {}".format(i + 1, n_runs))
        logger.info("Summarising " + specification_type + " for " + str(entry))
        try:
            spec_dict = dict(
                zip(spec_keys, [specification[entry[0]]] + list(entry[1:]) + ["tf_matrix"])
            )
            bool_index_entries = reduce(
                np.logical_and, [df[x] == y for x, y in spec_dict.items()]
            )
            # Pick out just the results for this combination of
            # settings
            resHere = df.loc[bool_index_entries, :]
            # Get summary stats for the run
            sumResHere = futil.runVSbchmarkData(
                resHere, spec_dict["horizon"], spec_dict["expanding"]
            )
            # Label the summary with the settings that produced it
            for key, value in spec_dict.items():
                sumResHere[key] = value
            all_summry_results = pd.concat([all_summry_results, sumResHere], axis=0)
        except Exception:
            # Best effort: record the failure and carry on with the next one.
            logger.exception("Could not summarise run " + str(entry))
    # Save run summaries to file
    out_path_summry = os.path.join(
        config["data"]["results"], "ALL_SUM_" + specification_type + ".pkl"
    )
    all_summry_results.to_pickle(out_path_summry)
    logger.info("Finished ML summaries!")

##
## Batch job code for running on cloud
##

def generate_csv_tf_ML():
    """
    Dump every model specification to a csv for batch processing.

    Writes one row per combination from
    generate_all_combinations_ML_fcasts() to
    "<specFileStem>_specifications.csv" in the output scratch directory,
    without the row index. Used by batch jobs in the cloud, which pick the
    rows up one-by-one.
    """
    spec_table = pd.DataFrame(generate_all_combinations_ML_fcasts())
    scratch_dir = config["data"]["outputscratch"]
    file_name = config["tf_vs_ML"]["specFileStem"] + "_specifications.csv"
    spec_table.to_csv(os.path.join(scratch_dir, file_name), index=False)


def collate_ML_results():
    """
    Collate successful and failed runs from cloud batch jobs.

    Every "<specFileStem>_run_*.csv" in the output scratch directory is
    read (first column as index), stacked, and pickled to
    "<specFileStem>_results.pkl" in the results directory. Every
    "<specFileStem>_failrun_*.csv" is stacked the same way and written to
    "<specFileStem>_failed_runs.csv".

    If no files of a kind are found, a message is printed and the
    corresponding output file is not written.
    """
    stem = config["tf_vs_ML"]["specFileStem"]
    results_dir = config["data"]["results"]
    path = os.path.join(config["data"]["outputscratch"], stem)

    # Successful runs. glob returns files in arbitrary order, so sort them to
    # make the row order of the collated output reproducible.
    run_frames = [
        pd.read_csv(filename, index_col=0, header=0)
        for filename in sorted(glob.glob(path + "_run_*.csv"))
    ]
    if run_frames:
        pd.concat(run_frames, axis=0).to_pickle(
            os.path.join(results_dir, stem + "_results.pkl")
        )
    else:
        # pd.concat raises on an empty list; report instead of crashing so
        # that any failed runs are still collated below.
        print("No successful runs found")

    # Collate failed runs
    fail_frames = [
        pd.read_csv(filename, index_col=0, header=0)
        for filename in sorted(glob.glob(path + "_failrun_*.csv"))
    ]
    if fail_frames:
        pd.concat(fail_frames, axis=0).to_csv(
            os.path.join(results_dir, stem + "_failed_runs.csv")
        )
    else:
        # Do nothing, there were no failed runs
        print("No failed runs found")

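##
## Some jobs failed on the cloud; they were re-run and their results appended to the combined results.
## The functions below append them in a safety-first way: the merged output is written to a separate
## "_resultsNEW.pkl" file instead of overwriting the existing results file.
##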

def fold_in_NN_ML_results():
    """
    Append the extra neural network forecast runs to the combined results.

    Reads "<specFileStem>_results.pkl" and "<specFileStem>_resultsNN.pkl"
    from the results directory, stacks them row-wise and writes the outcome
    to "<specFileStem>_resultsNEW.pkl". Rename resultsNEW after adding.
    """
    results_dir = config["data"]["results"]
    stem = config["tf_vs_ML"]["specFileStem"]
    combined = pd.read_pickle(os.path.join(results_dir, stem + "_results.pkl"))
    extra_runs = pd.read_pickle(os.path.join(results_dir, stem + "_resultsNN.pkl"))
    merged = pd.concat([combined, extra_runs], axis=0)
    # Written to a new file; the existing results file is left untouched
    merged.to_pickle(os.path.join(results_dir, stem + "_resultsNEW.pkl"))


def fold_in_OLSAR1_ML_results():
    """
    Append the extra OLSAR1 forecast runs to the combined results.

    Reads "<specFileStem>_results.pkl" and
    "<specFileStem>_resultsOLSAR1.pkl" from the results directory, stacks
    them row-wise and writes the outcome to "<specFileStem>_resultsNEW.pkl".
    Rename resultsNEW after adding.
    """
    results_dir = config["data"]["results"]
    stem = config["tf_vs_ML"]["specFileStem"]
    combined = pd.read_pickle(os.path.join(results_dir, stem + "_results.pkl"))
    extra_runs = pd.read_pickle(
        os.path.join(results_dir, stem + "_resultsOLSAR1.pkl")
    )
    merged = pd.concat([combined, extra_runs], axis=0)
    # Written to a new file; the existing results file is left untouched
    merged.to_pickle(os.path.join(results_dir, stem + "_resultsNEW.pkl"))


def fold_in_OLSfctr_ML_results():
    """
    Append the extra OLS factor model forecast runs to the combined results.

    Reads "<specFileStem>_results.pkl" and
    "<specFileStem>_results_fctrOLS.pkl" from the results directory, stacks
    them row-wise and writes the outcome to "<specFileStem>_resultsNEW.pkl".
    Rename resultsNEW after adding.
    """
    results_dir = config["data"]["results"]
    stem = config["tf_vs_ML"]["specFileStem"]
    combined = pd.read_pickle(os.path.join(results_dir, stem + "_results.pkl"))
    extra_runs = pd.read_pickle(
        os.path.join(results_dir, stem + "_results_fctrOLS.pkl")
    )
    merged = pd.concat([combined, extra_runs], axis=0)
    # Written to a new file; the existing results file is left untouched
    merged.to_pickle(os.path.join(results_dir, stem + "_resultsNEW.pkl"))
