#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
This is the algorithm forecast code. It uses text metrics
to make forecasts.

The complete package of recorded settings for each run is given by
(metric, yvar, paper, horizon, alpha, stepSize, expanding, trafo, specification, model_name, CV)

Note that only one x (text metric) and one y (target) variable are recorded;
the specification variable captures other information, such as whether the
model is a factor model or an AR(1).

Individual model run output format:
metric | target_name | target_val | \hat{y}_OOS | \hat{y}_IS | alpha | trafo |
        model_name | paper | stepSize | date | specification | CV | expanding

"""
import pandas as pd
import os
import json
import configparser
import itertools
from tqdm import tqdm
from pathlib import Path
import numpy as np
from functools import reduce
import src.fcastutils as futil
import src.generalutils as gutil
import importlib

# Re-execute the project helper modules that were imported above. This is
# presumably so that edits to them are picked up when this file is re-run in
# an interactive session; in a fresh process it is redundant.
importlib.reload(futil)
importlib.reload(gutil)
# ---------------------------------------------------------------------------
# Settings & file paths
# ---------------------------------------------------------------------------
# Module-level configuration, read from "config.ini" relative to the current
# working directory. Used by loopOverSpecsAlg below.
config = configparser.ConfigParser()
# Keep option names case-sensitive (the default optionxform lower-cases them).
config.optionxform = str
# NOTE(review): ConfigParser.read skips files it cannot open, so a missing
# config.ini only surfaces later as a missing-section/key error.
config.read("config.ini")


# ---------------------------------------------------------------------------
# Functions
# ---------------------------------------------------------------------------
def predictvsAR1(yvar, metric, horizon, paper, trafo, alpha, stepSize, expanding):
    r"""
    Forecast ``yvar`` ``horizon`` periods ahead from its own lag plus one text metric.

    Model specification:
        y_{t+h} = \alpha + \beta \cdot y_{t-1} + \eta \cdot x_t + \epsilon_t
    where x_t is the text ``metric``. Here \alpha is the regression constant,
    not the ``alpha`` argument (the initial-period run setting).

    Args:
        yvar: name of the target series requested from ``gutil.getTimeSeries``.
        metric: name of the text-metric series (x_t).
        horizon: number of periods ahead; the target is ``yvar`` shifted by
            ``-horizon``.
        paper: newspaper identifier passed to ``gutil.getTimeSeries``.
        trafo, alpha, stepSize, expanding: run settings, passed through
            unchanged to ``futil.modelSpecWrap``.

    Returns:
        Whatever ``futil.modelSpecWrap`` returns, run as an OLS model with
        specification "met_vs_ar1" and CV=False. Example of the data format:

        IS_prediction  OOS_prediction  metric  target_name  paper  horizon
            specification  model  trafo  CV  alpha  expanding  target_value
                run_type
        date
        1998-05-31  3.571299  NaN  opinion  MGDP  GRDN  3
            met_vs_ar1  OLS  none  False  3  False  3.3
                metric
    """
    # Retrieving data ----------
    df = gutil.getTimeSeries(metric, yvar, paper, freq="M")
    # Model specification ----------------------------
    target = "target"
    lagString = yvar + "_lag_1"
    metricList = [metric]
    controls = [lagString]
    # Control: one lag of yvar, i.e. y_{t-1}
    df[lagString] = df[yvar].shift(1)
    # Dependent variable: yvar `horizon` periods ahead, i.e. y_{t+h}.
    # NOTE(review): this column is not passed explicitly below; presumably
    # futil.modelSpecWrap reads the "target" column from df — verify.
    df[target] = df[yvar].shift(-horizon)
    # Full run specification:
    model_name = "OLS"
    CV = False
    specification = "met_vs_ar1"
    # Positional order must match futil.modelSpecWrap's parameters.
    packagedRunSettings = (
        yvar,
        metric,
        horizon,
        paper,
        trafo,
        alpha,
        stepSize,
        expanding,
        CV,
        model_name,
        specification,
    )
    finResults = futil.modelSpecWrap(
        df, metricList, controls, df[yvar], *packagedRunSettings
    )
    return finResults


def predictsvsFactorModel(
    yvar, metric, horizon, paper, trafo, alpha, stepSize, expanding, factors=2
):
    r"""
    Forecast ``yvar`` ``horizon`` periods ahead from its own lag, factor series
    and one text metric.

    Model:
        y_{t+h} = \alpha + \beta \cdot y_{t-1} + \sum_j \gamma_j \cdot F_{jt}
                  + \eta \cdot x_t + \epsilon_t
    where F_{jt} are the series "PC_1" ... "PC_<factors>" (presumably principal
    components — confirm) and x_t is the text ``metric``. \alpha is the
    regression constant, not the ``alpha`` argument.

    Args:
        yvar: name of the target series requested from ``gutil.getTimeSeries``.
        metric: name of the text-metric series (x_t).
        horizon: number of periods ahead; the target is ``yvar`` shifted by
            ``-horizon``.
        paper: newspaper identifier passed to ``gutil.getTimeSeries``.
        trafo, alpha, stepSize, expanding: run settings, passed through
            unchanged to ``futil.modelSpecWrap``.
        factors: number of factor series to include as controls (1-9).

    Raises:
        ValueError: if ``factors`` is outside 1..9.

    Returns:
        Whatever ``futil.modelSpecWrap`` returns, run as an OLS model with
        specification "met_vs_fctrs" and CV=False. Example of the data format:

        IS_prediction  OOS_prediction  metric  target_name  horizon  specification
            model  MGDP
        date
        1998-05-31  3.566262  NaN  opinion  MGDP  3  met_vs_fctrs  OLS  3.3
        1998-06-30  3.453735  NaN  opinion  MGDP  3  met_vs_fctrs  OLS  3.6
        1998-07-31  3.575717  NaN  opinion  MGDP  3  met_vs_fctrs  OLS  3.7
    """
    # Validate with an explicit exception: ``assert`` is stripped under -O.
    # A non-positive count would silently produce an empty factor list.
    # NOTE(review): the upper bound rejects factors == 10 although there are
    # said to be ten factors; confirm whether "PC_10" exists before relaxing.
    if not 1 <= factors < 10:
        raise ValueError(f"factors must be between 1 and 9, got {factors!r}")
    # Retrieving data ----------
    factorNamesList = ["PC_" + str(i) for i in range(1, factors + 1)]
    df = gutil.getTimeSeries(metric, [yvar] + factorNamesList, paper, freq="M")
    # Basic model specification ----------------------------
    target = "target"
    lagString = yvar + "_lag_1"
    controls = gutil.singleListCheck(factorNamesList) + [lagString]
    metricList = [metric]
    # Control: one lag of yvar, i.e. y_{t-1}
    df[lagString] = df[yvar].shift(1)
    # Dependent variable: yvar `horizon` periods ahead, i.e. y_{t+h}.
    # NOTE(review): presumably consumed by futil.modelSpecWrap via the
    # "target" column — verify.
    df[target] = df[yvar].shift(-horizon)
    specification = "met_vs_fctrs"
    model = "OLS"
    CV = False
    # Positional order must match futil.modelSpecWrap's parameters.
    packagedRunSettings = (
        yvar,
        metric,
        horizon,
        paper,
        trafo,
        alpha,
        stepSize,
        expanding,
        CV,
        model,
        specification,
    )
    finResults = futil.modelSpecWrap(
        df, metricList, controls, df[yvar], *packagedRunSettings
    )
    return finResults


def loopOverSpecsAlg(specification, func_to_run):
    """
    Run ``func_to_run`` over every combination of run settings from config.ini.

    The grid is the Cartesian product of:
    - targets (keys of ``gutil.allTargetsDict()``);
    - metrics (keys of ``gutil.allMetricsDict()``);
    - horizons (``[runSettings] horizons``);
    - papers (``[papers] paper_list`` plus "COMB");
    - one fixed value each for trafo ("none"), alpha (36), stepSize (1) and
      expanding (False).

    Results are accumulated in two pickles under ``config["data"]["results"]``:
    - ``ALL_<specification>.pkl``: the raw output of ``func_to_run``;
    - ``ALL_SUM_<specification>.pkl``: the summary from
      ``futil.runVSbchmarkData``, tagged with the run settings.

    Both files are rewritten after every successful run.

    A combination is skipped when ``ALL_<specification>.pkl`` already holds
    matching rows. This does NOT detect changes to config.ini, which can matter
    for the "COMB" (combination) newspaper.

    A combination whose forecast or summary raises an exception is reported
    and skipped; the loop continues with the next one.

    Args:
        specification: label used in the output file names and written to the
            summary file's "specification" column.
        func_to_run: forecast function, called positionally as
            ``func_to_run(target, metric, horizon, paper, trafo, alpha,
            stepSize, expanding)``.
    """
    papersL = json.loads(config.get("papers", "paper_list"))
    metricsL = list(gutil.allMetricsDict().keys())
    targetsL = list(gutil.allTargetsDict().keys())
    # Column names for the summary rows; the first len(entry) are paired with
    # the grid entry values in order.
    namesVars = json.loads(config.get("runSettings", "runSettings"))
    horizonsL = json.loads(config.get("runSettings", "horizons"))
    alpha, stepSize, trafo, expanding, CV = 36, 1, "none", False, False
    # Grid dimension names, in func_to_run's positional order; used to find
    # runs already stored in the results file.
    specKeys = [
        "target",
        "metric",
        "horizon",
        "paper",
        "trafo",
        "alpha",
        "stepSize",
        "expanding",
    ]
    argsCombos = list(
        itertools.product(
            targetsL,
            metricsL,
            horizonsL,
            papersL + ["COMB"],
            [trafo],
            [alpha],
            [stepSize],
            [expanding],
        )
    )
    resultsDir = Path(config["data"]["results"])
    outPath = resultsDir / ("ALL_" + specification + ".pkl")
    outPathSum = resultsDir / ("ALL_SUM_" + specification + ".pkl")
    nCombos = len(argsCombos)
    # Give tqdm the list itself (not an enumerate object) so it knows the total.
    for iteration, entry in enumerate(tqdm(argsCombos), start=1):
        print(" \n -------------------------- \n")
        print("Iteration {} of {}".format(iteration, nCombos))
        print("Running " + specification + " for " + str(entry))
        spec_dict = dict(zip(specKeys, entry))
        # Check if file exists already; if not, create empty placeholders
        if not outPath.is_file():
            print("Results file not found; creating")
            allResults = pd.DataFrame()
            allResults.to_pickle(outPath)
            pd.DataFrame().to_pickle(outPathSum)
            in_results_already = False
        else:
            allResults = pd.read_pickle(outPath)
            if allResults.empty:
                # Nothing stored yet (e.g. every earlier run failed). An empty
                # frame has no columns, so the lookup below would raise KeyError.
                in_results_already = False
            else:
                isMatch = reduce(
                    np.logical_and,
                    [allResults[key] == value for key, value in spec_dict.items()],
                )
                in_results_already = isMatch.sum() != 0
        if in_results_already:
            print(f'Existing results found for specification {" & ".join([f"{x} == {y}" for x, y in spec_dict.items()])}')
            continue
        try:
            resHere = func_to_run(*entry)
            # Summary stats for the run; passes horizon (entry[2]) and
            # expanding (entry[-1]).
            sumResHere = futil.runVSbchmarkData(resHere, entry[2], entry[-1])
            for col, value in zip(namesVars[: len(entry)], entry):
                sumResHere[col] = value
        except Exception as err:
            # Best-effort: report the failure and move on to the next
            # combination instead of aborting the whole grid.
            print(f"Forecast failed for {entry}: {err!r}")
            continue
        allResults = pd.concat([allResults, resHere], axis=0)
        allSumResults = pd.concat([pd.read_pickle(outPathSum), sumResHere], axis=0)
        # Tag every summary row with the run-level settings.
        allSumResults["specification"] = specification
        allSumResults["CV"] = CV
        allSumResults["model"] = "OLS"
        # Persist after every run so an interrupted loop can be resumed.
        allResults.to_pickle(outPath)
        allSumResults.to_pickle(outPathSum)
    print("Finished!")


def run_all_AR1_text_metric_fcasts():
    """Run predictvsAR1 over the full config.ini grid via loopOverSpecsAlg.

    Results go to the "ALL_met_vs_AR1.pkl" / "ALL_SUM_met_vs_AR1.pkl" files.

    NOTE(review): predictvsAR1 tags its own rows with the specification
    "met_vs_ar1" (lower case), whereas this label is "met_vs_AR1". Confirm
    whether the two are meant to differ.
    """
    loopOverSpecsAlg(func_to_run=predictvsAR1, specification="met_vs_AR1")


def run_all_factor_model_text_metric_fcasts():
    """Run predictsvsFactorModel over the full config.ini grid via loopOverSpecsAlg.

    Results go to the "ALL_met_vs_fctrs.pkl" / "ALL_SUM_met_vs_fctrs.pkl"
    files. The factor count is predictsvsFactorModel's default.
    """
    loopOverSpecsAlg(func_to_run=predictsvsFactorModel, specification="met_vs_fctrs")
