import matplotlib.pyplot as plt
import configparser
import json
import os
import pandas as pd
import itertools
import src.generalutils as gutil
import src.fcastutils as futil
from sklearn.model_selection import train_test_split
import numpy as np
from glob import glob
from pathlib import Path
from sklearn.inspection import partial_dependence
from sklearn.inspection import PartialDependenceDisplay


def _load_config(path="config.ini"):
    """Return a ConfigParser populated from *path*, preserving option-name case.

    ``optionxform = str`` disables ConfigParser's default lower-casing of
    option names. ``ConfigParser.read`` silently skips a file it cannot open.
    """
    parser = configparser.ConfigParser()
    parser.optionxform = str
    parser.read(path)
    return parser


config = _load_config()
# Apply the [viz] section of the config as matplotlib style settings.
plt.style.use(config["viz"])


# =============================================================================
# REVISIONS: Partial dependence (variable importance) exercise
# =============================================================================
# ----------------------------------------------
def _prepare_forecast_frame(
    yvar, horizon, paper, trafo, alpha, stepSize, factorNamesList, target="target"
):
    """Build the modelling frame for a single target variable.

    Loads the text-metric time series for *yvar* plus the factor controls,
    adds a one-period lag of *yvar* and the *horizon*-ahead ``target`` column,
    applies the as-if-real-time text-metric transformation, drops *yvar* and
    passes the result through ``futil.prepForModelling``.

    Returns
    -------
    tuple (pandas.DataFrame, list)
        The prepared frame and the exogenous column names (text metrics
        followed by controls).
    """
    df = gutil.getTFTimeSeries([yvar] + factorNamesList, paper, freq="M")
    # This avoids a really tricky bug where target is the name of one of the
    # words tracked in the tf matrix.
    if target in df.columns:
        df = df.rename(columns={target: target + "_word"})
    lagString = yvar + "_lag_1"
    controls = gutil.singleListCheck(factorNamesList) + [lagString]
    # Computed before the lag/target columns are added so neither of them is
    # treated as a text metric.
    metricList = [x for x in df.columns if x != yvar and x not in controls]
    # One lag of yvar, and yvar `horizon` steps ahead as the target.
    df[lagString] = df[yvar].shift(1)
    df[target] = df[yvar].shift(-horizon)
    allExog = metricList + controls
    # As if real-time transformation of text metrics:
    df[metricList] = futil.txtMetricTransformer(
        df[metricList], alpha, stepSize, trafo
    )
    # Drop yvar so it cannot accidentally be used in the forecasts; only the
    # shifted target column is kept.
    df = df.drop(yvar, axis=1)
    df = futil.prepForModelling(df, allExog + [target])
    return df, allExog


def _partial_dependence_importance(model_ml, X_train, X_test, num_points=50):
    """Return partial-dependence importances for every feature, summing to 1.

    For each column of *X_train* an evenly spaced grid of *num_points* values
    between the column's training min and max is built. For each grid value
    that column is overwritten in a copy of *X_test* and the mean model
    prediction is recorded, giving a partial-dependence curve. A feature's
    importance is the sample standard deviation of its curve, divided by the
    sum over all features.
    """
    curves = {}
    for col in X_train.columns:
        grid = np.linspace(np.min(X_train[col]), np.max(X_train[col]), num_points)
        curve = []
        for value in grid:
            X_pdp = X_test.copy()
            X_pdp[col] = value
            curve.append(np.mean(model_ml.predict(X_pdp)))
        curves[col] = curve
    var_imp = pd.DataFrame(curves).std()
    return var_imp / var_imp.sum()


def create_partial_dependence_data(targets=None, models=None):
    """Fit each ML model per target and save partial-dependence importances.

    For every target and model, fits ``futil.model_selection(model)`` on a
    random 80/20 split of the prepared data and computes normalised
    partial-dependence importances on the test split.

    Files written, under ``<results>/partial_dependence_separate`` where
    ``<results>`` is ``config["data"]["results"]``:

    * ``PartDepValues<target>_<model>_<spec>.csv``, one per (target, model).
    * ``PartDepValues<target>_<spec>.csv``, one per target with all models
      side by side.

    Parameters
    ----------
    targets : iterable of str, optional
        Target variables to process. Defaults to all keys of
        ``gutil.allTargetsDict()``.
    models : iterable of str, optional
        Model names to fit. Defaults to ``runSettings/MLmodels`` from the
        config.
    """
    if models is None:
        modelsL = json.loads(config.get("runSettings", "MLmodels"))
    else:
        modelsL = list(models)
    if targets is None:
        targetsL = list(gutil.allTargetsDict().keys())
    else:
        targetsL = list(targets)

    # Fixed run settings for this exercise.
    alpha, stepSize, trafo = 36, 1, "none"
    horizon, paper = 3, "DAIM"
    specification = "tf_vs_fctrOLS"
    target = "target"
    factorNamesList = ["PC_" + str(i) for i in range(1, 2 + 1)]
    out_dir = os.path.join(config["data"]["results"], "partial_dependence_separate")
    os.makedirs(out_dir, exist_ok=True)

    for targetname in targetsL:
        # Data preparation does not depend on the model, so do it once per target.
        df, allExog = _prepare_forecast_frame(
            targetname, horizon, paper, trafo, alpha, stepSize, factorNamesList, target
        )
        importances = []
        for model in modelsL:
            print("Model name:   " + model)
            # Shuffled split without a fixed seed: results vary between runs.
            X_train, X_test, Y_train, _ = train_test_split(
                df[allExog], df[[target]], test_size=0.2
            )
            model_ml = futil.model_selection(model)
            model_ml.fit(X_train, Y_train)

            var_imp = _partial_dependence_importance(model_ml, X_train, X_test)
            var_imp.name = model
            var_imp = var_imp.to_frame()
            var_imp.to_csv(
                os.path.join(
                    out_dir,
                    "PartDepValues"
                    + targetname
                    + "_"
                    + model
                    + "_"
                    + specification
                    + ".csv",
                )
            )
            importances.append(var_imp)

        partval = pd.concat(importances, axis=1) if importances else pd.DataFrame()
        # NOTE(review): this file also matches the "PartDepValues*.csv" glob in
        # combine_partial_dependence_files; make sure that function's model
        # filter excludes it.
        partval.to_csv(
            os.path.join(
                out_dir,
                "PartDepValues" + targetname + "_" + specification + ".csv",
            )
        )


def combine_partial_dependence_files(in_dir=None, out_file=None, models=None):
    """Merge the per-(target, model) partial-dependence CSVs into one table.

    Reads every ``PartDepValues<target>_<model>_<spec>.csv`` in *in_dir*
    whose ``<model>`` field is exactly one of *models*. The files are stacked
    into a frame indexed by (target_var, specification_fcast, ml_model). The
    target-specific ``*lag_1*`` columns are collapsed into a single
    ``lag_col`` column, with missing values counted as 0. The result is
    written to *out_file*, rounded to 4 decimal places.

    Parameters
    ----------
    in_dir : path-like, optional
        Directory holding the per-model files. Defaults to
        ``data/results/partial_dependence_separate``.
    out_file : path-like, optional
        Destination CSV. Defaults to ``data/results/PartialDependence.csv``.
    models : iterable of str, optional
        Model names to include. Defaults to ``runSettings/MLmodels`` from the
        config.
    """
    prefix = "PartDepValues"
    if in_dir is None:
        in_dir = Path("data/results/partial_dependence_separate")
    else:
        in_dir = Path(in_dir)
    if out_file is None:
        out_file = Path("data/results") / "PartialDependence.csv"
    else:
        out_file = Path(out_file)
    if models is None:
        models = json.loads(config.get("runSettings", "MLmodels"))
    # Longest names first, so a model whose name is a prefix of another
    # (e.g. "RF" vs "RF_big") cannot shadow it.
    modelsL = sorted(models, key=len, reverse=True)
    index_vars = ["target_var", "specification_fcast", "ml_model"]

    frames = []
    # Sorted so the output row order does not depend on the filesystem.
    for file in sorted(in_dir.glob(prefix + "*.csv")):
        # Stem layout: PartDepValues<target>_<model>_<spec>
        head, _, rest = file.stem.partition("_")
        model = next((m for m in modelsL if rest.startswith(m + "_")), None)
        if model is None:
            continue  # not a per-model file, or a model we were not asked for
        # Slice off the literal prefix (str.lstrip would strip a *set* of
        # characters and mangle targets such as "PSales").
        target = head[len(prefix):]
        spec = rest[len(model) + 1:]
        temp_df = pd.read_csv(file, index_col=0).T
        for col, value in zip(index_vars, (target, spec, model)):
            temp_df[col] = value
        frames.append(temp_df.set_index(index_vars))
    df = pd.concat(frames, axis=0) if frames else pd.DataFrame()

    # Make consistent names for all lagged variables: each target has its own
    # "<target>_lag_1" column, so collapse them into one (NaN counts as 0).
    lag_cols = [x for x in df.columns if "lag_1" in x]
    df["lag_col"] = df[lag_cols].sum(axis=1)
    df = df.drop(lag_cols, axis=1)
    df.round(4).to_csv(out_file)

# =============================================================================
# Plot variable importance metrics
# =============================================================================


def plot_variable_importance_metrics(targets=None):
    """Plot the top-10 partial-dependence importances per model for each target.

    Reads ``data/results/PartialDependence.csv`` (the file written by
    ``combine_partial_dependence_files``). For each target it draws a 2x3 grid
    of bar charts, one panel per model in ``runSettings/MLmodels``. Each panel
    shows the ten largest importances, scaled so each column's maximum is 1.
    The figure is saved as ``ParDepGraphs<target>_<spec>.eps`` in
    ``config["data"]["output"]`` and then closed.

    Targets with no rows are skipped. Models with no rows get an empty,
    titled panel. Panels beyond the number of models are hidden.

    Parameters
    ----------
    targets : iterable of str, optional
        Targets to plot. Defaults to ``["MGDP", "CPIall", "LFSURATE"]``.
    """
    in_path = Path("data/results")
    df = pd.read_csv(in_path / "PartialDependence.csv")
    if targets is None:
        targetsL = ["MGDP", "CPIall", "LFSURATE"]
    else:
        targetsL = list(targets)
    modelsL = json.loads(config.get("runSettings", "MLmodels"))
    nicifyDict = gutil.nameConvert()
    index_vars = ["target_var", "specification_fcast", "ml_model"]

    for target in targetsL:
        xf = df.loc[df["target_var"] == target]
        if xf.empty:
            print(f"No partial dependence results for {target}; skipping.")
            continue
        fig, axes = plt.subplots(nrows=2, ncols=3, figsize=(12, 15), sharey=True)
        try:
            fig.subplots_adjust(hspace=0.75, wspace=0.1)
            axes_flat = axes.flatten()
            for model, ax in zip(modelsL, axes_flat):
                sub_xf = xf.loc[xf["ml_model"] == model]
                if sub_xf.empty:
                    ax.set_title(model)
                    ax.axis("off")
                    continue
                # Features as rows, one column per result row; keep the 10
                # largest and scale each column by its maximum.
                bar_data = (
                    sub_xf.drop(columns=index_vars)
                    .T.nlargest(n=10, columns=list(sub_xf.index))
                )
                bar_data = bar_data.divide(bar_data.max())
                bar_data.plot.bar(ax=ax, title=model, legend=False)
            # Hide panels left over when fewer than six models are configured.
            for ax in axes_flat[len(modelsL):]:
                ax.set_visible(False)
            fig.suptitle(f"Variable Importance: {nicifyDict[target]}")
            specification = xf["specification_fcast"].unique()[0]
            fig.tight_layout()
            fig.savefig(
                os.path.join(
                    config["data"]["output"],
                    "ParDepGraphs" + target + "_" + specification + ".eps",
                )
            )
        finally:
            # Release the figure so repeated calls do not accumulate open figures.
            plt.close(fig)
