#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""

This produces the data behind the correlation tables and swathe plots,
plus autocorrelation, augmented Dickey-Fuller (ADF) and Granger
causality (GC) tables.

It does not produce forecasts.

Tables are exported directly to latex as .txt

Data for figures is output in csvs that are ingested later by figs
and tables.py

ALL = contains all newspapers but individually
COMB = contains all newspapers but aggregated

"""
import numpy as np
import pandas as pd
import os
import json
import configparser
from loguru import logger
import src.generalutils as gutil
import src.txtutils as txtutil
import statsmodels.tsa.stattools as smts
import importlib

# Reload the local utilities module (presumably so that edits to it are picked
# up when this script is re-run in an interactive session -- TODO confirm)
importlib.reload(gutil)
# ---------------------------------------------------------------------------
# Settings & file paths
# ---------------------------------------------------------------------------
config = configparser.ConfigParser()
# Keep option names case-sensitive (configparser lower-cases them by default)
config.optionxform = str
# Relative path: resolved against the current working directory
config.read("config.ini")

# The list of papers is stored in config.ini as a JSON-encoded string
paper_list = json.loads(config.get("papers", "paper_list"))
# Get all metrics from config file. Will throw
# error if they don't all exist in dataframe
metrics_list = list(gutil.allMetricsDict().keys())
# ---------------------------------------------------------------------------
# Functions
# ---------------------------------------------------------------------------


def autocorrelationTable(df, seriesToTry, lag):
    """Build the autocorrelation table for the series in seriesToTry.

    Rows are labelled "t-1" ... "t-<lag>"; there is one column per name in
    seriesToTry holding df[name].autocorr(k) for every lag k from 1 to lag.
    """
    lagSteps = range(1, lag + 1)
    table = pd.DataFrame(index=["t-" + str(step) for step in lagSteps])
    for column in seriesToTry:
        # Values are assigned positionally: one coefficient per lag label
        table[column] = [df[column].autocorr(step) for step in lagSteps]
    return table


def combinedAutocorrelation(paper_list, metrics, lag=5, save=False):
    """
    Autocorrelations (lags 1..lag) of the monthly-mean metrics of each paper.

    For every paper the metrics data is loaded, averaged by month and passed
    to autocorrelationTable. The per-paper tables are stacked and melted into
    a tidy dataframe with columns newspaper, lag, metric and value.
    When save is True the result is written to ALL_autocorr.csv in the
    results directory, otherwise it is logged. It is returned either way.
    """
    fileNames = [gutil.paperFname(paper) for paper in paper_list]
    stacked = pd.DataFrame()
    for paper, fileName in zip(paper_list, fileNames):
        monthlyMeans = (
            gutil.loadMetricsData(fileName).groupby(pd.Grouper(freq="M")).mean()
        )
        paperTable = autocorrelationTable(monthlyMeans, metrics, lag)
        paperTable["newspaper"] = paper
        # Each new table is placed on top of the ones gathered so far
        stacked = pd.concat([paperTable, stacked], axis=0)
    tidy = stacked.reset_index().melt(
        id_vars=["newspaper", "index"], var_name="metric", value_name="value"
    )
    # After reset_index the lag labels ("t-1", ...) live in the "index" column
    tidy = tidy.rename(columns={"index": "lag"})
    outFile = os.path.join(config["data"]["results"], "ALL_autocorr.csv")
    if not save:
        logger.info(tidy)
    else:
        tidy.to_csv(outFile)
    return tidy


def combinedADFTest(paper_list, metrics, save=False):
    """
    Augmented Dickey-Fuller tests on the monthly-mean metrics of each paper.

    Returns a dataframe of the form
                                GRDN No. obs.           DMIR No. obs.
    Afinn_sentiment               -03.03**      313  -04.28***      247
    Harvard_sentiment            -04.17***      312  -07.83***      248
    Loughran_sentiment              -02.20      313  -04.77***      248

    Args:
        paper_list (list): papers to test; each must be accepted by
            gutil.paperFname and be a key of gutil.nameConvert()
        metrics (list): metric columns of the loaded data to test
        save (bool): if True write the table as LaTeX to ALL_ADFtable.txt
            in the output directory, otherwise log it
    Returns:
        Pandas dataframe with the per-paper test tables side by side and
        the index mapped through gutil.nameConvert()
    """
    freq = "M"
    nameDict = gutil.nameConvert()
    fNames = [gutil.paperFname(paper) for paper in paper_list]
    saveADF = []
    for paper, fName in zip(paper_list, fNames):
        df = gutil.loadMetricsData(fName)
        # One monthly grouping serves both the counts and the means
        monthly = df[metrics].groupby(pd.Grouper(freq=freq))
        # Only keep months where the number of non-missing observations,
        # averaged across the metrics, exceeds 15
        dfCounts = monthly.count().mean(axis=1)
        dfCounts = dfCounts[dfCounts > 15]
        xf = monthly.mean().loc[dfCounts.index, :].dropna(how="all")
        paperTable = gutil.augmentedDickeyFullerTestTable(xf, metrics).T
        # Label the statistic column with the paper's display name
        paperTable = paperTable.rename(columns={"ADF statistic": nameDict[paper]})
        saveADF.append(paperTable)
    appended_data = pd.concat(saveADF, axis=1)
    appended_data.index = appended_data.index.map(nameDict)
    saveString = os.path.join(config["data"]["output"], "ALL_ADFtable.txt")
    if save:
        appended_data.to_latex(saveString)
    else:
        logger.info(appended_data)
    return appended_data


def GCTableSort(GCMat):
    """
    Sorts GC table so that the largest test statistics are in the
    top left of the matrix.

    GCMat holds the statistics as strings with optional trailing
    significance stars (e.g. "4.21**"). Rows are ordered by descending row
    mean and columns by descending column mean of the star-stripped values.
    The cells of the returned table keep their original starred strings;
    row and column labels are mapped through gutil.nameConvert().
    """
    numeric = GCMat.copy()
    for col in numeric.columns:
        # regex=False: "*" has to be removed literally. As a regular
        # expression it is a bare quantifier, and the default of the regex
        # argument differs between pandas versions.
        numeric[col] = pd.to_numeric(numeric[col].str.replace("*", "", regex=False))
    # Collate the metrics in descending order of their average statistic
    sortedIndex = numeric.mean(axis=1).sort_values(ascending=False).index
    sortedCols = numeric.mean(axis=0).sort_values(ascending=False).index
    GCMat = GCMat.reindex(columns=sortedCols).reindex(index=sortedIndex)
    nameDict = gutil.nameConvert()
    GCMat.index = GCMat.index.map(nameDict)
    # Column lookups are strict: a column missing from nameDict raises KeyError
    GCMat = GCMat.rename(columns={col: nameDict[col] for col in GCMat.columns})
    return GCMat


def _significanceStars(pValue):
    """Return the star suffix for a p-value: 1%: ***, 5%: **, 10%: *."""
    if pValue < 0.01:
        return "***"
    if pValue < 0.05:
        return "**"
    if pValue < 0.1:
        return "*"
    return ""


def GrangerCauseTable(cause_columns, affect_rows, newspaper):
    """This returns the results of a Granger Causality test of the
    effect of the cause_columns on the affect_rows.
        WARNING: the percentages which are starred are:
            1%: ***, 5%: **, 10%: *
    The lag is fixed at 3 inside the function and the reported figure is the
    F statistic of the "params_ftest" result at that lag.
    Args:
        cause_columns (str or list): series to test for their effect on
        other series
        affect_rows (str or list): series to test for being affected
        newspaper (str): paper whose data is pulled via gutil.getTimeSeries
    Returns:
        Pandas dataframe of test results (strings such as "4.21**") with
        ROWS as causers and COLUMNS as caused. This inversion of input is
        to make for easier reading of results
    """
    norm = False  # Compare across series
    roll = False  # Actually true but do it w/in this function
    bmLocalTrafo = True  # Some bm data may not be meaningful w/out this
    lag_value = 3
    allProxies = set(gutil.allProxiesDict())
    # Accept a single series name as well as a list of names
    if isinstance(cause_columns, str):
        cause_columns = [cause_columns]
    if isinstance(affect_rows, str):
        affect_rows = [affect_rows]
    resultsDictRow = {}
    for row in affect_rows:
        resultsDictCol = {}
        for col in cause_columns:
            # Whichever of the pair is a known proxy is passed second;
            # presumably getTimeSeries expects (metrics, proxies, paper)
            # -- TODO confirm against gutil
            if row in allProxies:
                firstArg, secondArg = col, row
            else:
                firstArg, secondArg = row, col
            df = gutil.getTimeSeries(
                firstArg,
                secondArg,
                newspaper,
                norm=norm,
                roll=roll,
                bmLocalTrafo=bmLocalTrafo,
            )
            logger.info("Data pulled for GCT test of " + col + " on " + row)
            # Column order matters: statsmodels tests whether the SECOND
            # column Granger-causes the FIRST one
            resultsGC = smts.grangercausalitytests(
                df[[row, col]].dropna(), maxlag=lag_value, verbose=False
            )[lag_value]
            fTest = resultsGC[0]["params_ftest"]
            fStat, pValue = fTest[0], fTest[1]
            resultsDictCol[col] = "{:02.2f}".format(fStat) + _significanceStars(pValue)
        resultsDictRow[row] = resultsDictCol
    # Outer keys (affected series) become columns, inner keys (causes) rows
    return pd.DataFrame.from_dict(resultsDictRow)


def GC_tables(paper_list, metrics_list, save=False):
    """
    Runs the Granger causality tests in both directions for every paper.

    For each paper one table tests the text metrics as causes of the
    proxies and a second tests the proxies as causes of the text metrics.
    Both tables are sorted with GCTableSort and logged; when save is True
    their transposes are also written as LaTeX to
    <paper>_GCT_txtcause.txt and <paper>_GCT_pxycause.txt in the output
    directory.
    """
    outDir = config["data"]["output"]
    proxies = list(gutil.allProxiesDict().keys())
    for paper in paper_list:
        logger.info("Running GCT on " + paper)
        txtCausePath = os.path.join(outDir, paper + "_GCT_txtcause.txt")
        pxyCausePath = os.path.join(outDir, paper + "_GCT_pxycause.txt")
        # Text metrics as causes of the proxies ...
        txtCauses = GrangerCauseTable(metrics_list, proxies, paper)
        logger.info("... one way GCT done ...")
        # ... and the proxies as causes of the text metrics
        pxyCauses = GrangerCauseTable(proxies, metrics_list, paper)
        txtCauses = GCTableSort(txtCauses)
        pxyCauses = GCTableSort(pxyCauses)
        if save:
            for table, path in ((txtCauses, txtCausePath), (pxyCauses, pxyCausePath)):
                table.T.to_latex(path)
        for table in (txtCauses, pxyCauses):
            logger.info(table)


def writeMetTestTableToFile():
    """
    Exports the scored "TEST" example texts as a LaTeX table.

    The metric scores of the test data are placed side by side with their
    underlying texts, cut down to the text and a fixed set of metrics,
    relabelled through gutil.nameConvert(), indexed by the text, rounded to
    2 decimals and written to testExamples.txt in the output directory.
    """
    sampleName = "TEST"
    scoreFile = gutil.paperFname(sampleName)
    # Grab the scores, then the texts they were computed from
    scores = gutil.loadMetricsData(scoreFile)
    texts = txtutil.getNewsText(sampleName, nrows=None)
    combined = pd.concat([scores.reset_index(), texts], axis=1)
    wanted = [
        "text",
        "word_count_econom",
        "alexopoulos",
        "stability",
        "vader",
        "tf_idf_econom",
    ]
    unwanted = [col for col in combined.columns if col not in wanted]
    combined = combined.drop(columns=unwanted)
    labels = gutil.nameConvert()
    labels["text"] = "Text"
    combined = combined.rename(columns=labels).set_index("Text")
    target = os.path.join(config["data"]["output"], "testExamples.txt")
    # Very wide column limit so the example texts are not truncated
    with pd.option_context("max_colwidth", 100000):
        combined.round(2).to_latex(target)


def swatheData(paper_list, save=False):
    """
    Creates mean/median of text metrics and min/max swathe from proxies.
    Does this over the paper list and includes all metrics.
    All proxies are included by default.

    For every paper and every ending in ("u", "s") the series named in the
    config sections txtmetrics_<ending> and proxies_<ending> are pulled
    with gutil.getTimeSeries (normalised, monthly), smoothed with a 3 month
    moving average, and summarised row-wise into <section>_mean, _max, _min
    and _median columns. Series listed under negate_<ending> in the
    [negate] config section are sign-flipped before summarising.

    NOTE (carried over from the original notes, not visible in this
    function -- TODO confirm in gutil.getTimeSeries): swathe data are from
    the inner join of all series, so only dates for which text metrics
    exist are covered and lots of time periods may be missed; only months
    with >15 article entries are used.

    Args:
        paper_list (list): papers to build swathe data for
        save (bool): if True write ALL_swathes.csv to the results
            directory, otherwise log the data
    Returns:
        Pandas dataframe of the summary columns plus a "paper" column,
        stacked over papers
    """
    freq = "M"
    monthSmooth = 3
    txt_endings = ["u", "s"]
    types = ["proxies", "txtmetrics"]
    logger.info("Swathe data working on " + str(monthSmooth) + " smoothing")
    logger.info("\n and at " + freq + " frequency")
    outSwathes = pd.DataFrame()
    logger.info(
        "Note that this uses a strict all data must be present in cross"
        + "-section rule, so will chop ends off text metrics if proxies do"
        + " not exist."
    )
    norm = True  # Compare across series
    roll = False  # Actually true but do it w/in this function
    bmLocalTrafo = True  # Some bm data may not be meaningful w/out this
    for paper in paper_list:
        logger.info("\n --------- \n")
        logger.info("Creating swathe data for " + paper)
        outPaper = pd.DataFrame()
        for end in txt_endings:
            metrics_list = list(dict(config["txtmetrics_" + end]))
            nonmetrics = list(dict(config["proxies_" + end]))
            negate_list = json.loads(config.get("negate", "negate_" + end))
            # The inputs to the loader do not depend on the series type, so
            # pull and smooth the data once per ending and reuse it below
            smoothed = (
                gutil.getTimeSeries(
                    metrics_list,
                    nonmetrics,
                    paper,
                    norm=norm,
                    roll=roll,
                    bmLocalTrafo=bmLocalTrafo,
                    freq=freq,
                )
                .rolling(monthSmooth)
                .mean()
            )
            for typeNow in types:
                nameNow = typeNow + "_" + end
                # Work on a copy so the summary columns of one series type
                # never leak into the frame used for the other
                xf = smoothed.copy()
                baseSeries = xf[list(dict(config[nameNow]).keys())].copy()
                for col in baseSeries.columns:
                    if col in negate_list:
                        logger.info("negating " + col + " for " + end)
                        baseSeries[col] = -baseSeries[col]
                xf[nameNow + "_mean"] = baseSeries.apply(np.nanmean, axis=1)
                xf[nameNow + "_max"] = baseSeries.apply(np.nanmax, axis=1)
                xf[nameNow + "_min"] = baseSeries.apply(np.nanmin, axis=1)
                xf[nameNow + "_median"] = baseSeries.apply(np.nanmedian, axis=1)
                # Keep only the summary columns created for this section
                allNames = [x for x in xf.columns if nameNow in x]
                outPaper = pd.concat([outPaper, xf[allNames]], axis=1)
        outPaper["paper"] = paper
        outSwathes = pd.concat([outPaper, outSwathes], axis=0)
    if save:
        outPath = os.path.join(config["data"]["results"], "ALL_swathes.csv")
        outSwathes.to_csv(outPath)
    else:
        logger.info(outSwathes)
    return outSwathes


def corrMatData(paper_list, metrics_list, save=False):
    """
    Creates correlation matrix data
    Does this over the paper list and includes all metrics.
    All proxies from gutil.allProxiesDict() are included by default.

    For every paper and every horizon h in the runSettings/horizons config
    entry, the proxies are shifted so that each metric at time t is paired
    with each proxy at time t + h, and the pairwise correlations are taken.
    Output is tidy data with columns:
    metric | proxy | correlation | horizon | paper

    NOTE (carried over from the original notes, not visible in this
    function -- TODO confirm in gutil.getTimeSeries): this uses the inner
    join of the time series and so will miss out lots of rows.

    Args:
        paper_list (list): papers to compute correlations for
        metrics_list (list): text metric series to correlate with proxies
        save (bool): if True write ALL_corrs.csv to the results directory,
            otherwise log the data
    Returns:
        Pandas dataframe of the tidy correlations (empty if there are no
        papers or no horizons)
    """
    # Iterate over paper and horizon
    horizons = json.loads(config.get("runSettings", "horizons"))
    freq = "M"
    # Leave data unchanged except bm locals to get stationary series:
    norm = True  # Compare across series
    roll = False  # Actually true but do it w/in this function
    bmLocalTrafo = True  # Some bm data may not be meaningful w/out this
    proxiesList = list(gutil.allProxiesDict())
    # Gather the pieces and concatenate once at the end; this avoids both
    # repeated copying and seeding the concat with an empty dataframe
    pieces = []
    for paper in paper_list:
        logger.info("\n --------- \n")
        logger.info("Creating corr data for " + paper)
        xf = gutil.getTimeSeries(
            metrics_list,
            proxiesList,
            paper,
            norm=norm,
            roll=roll,
            bmLocalTrafo=bmLocalTrafo,
            freq=freq,
        )
        for hori in horizons:
            temp_xf = xf.copy()
            # Shift proxies hori units ahead: row t then holds proxy(t + hori)
            for proxy in proxiesList:
                temp_xf[proxy] = temp_xf[proxy].shift(-hori)
            corrMat = (
                temp_xf[proxiesList + metrics_list]
                .corr()
                .loc[metrics_list, proxiesList]
            )
            corrMat = (
                corrMat.reset_index()
                .melt(id_vars=["index"], var_name="proxy", value_name="correlation")
                .rename(columns={"index": "metric"})
            )
            corrMat["horizon"] = hori
            corrMat["paper"] = paper
            pieces.append(corrMat)
    # pd.concat raises on an empty list, so fall back to an empty frame
    outCorr = pd.concat(pieces, axis=0) if pieces else pd.DataFrame()
    if save:
        outPath = os.path.join(config["data"]["results"], "ALL_corrs.csv")
        outCorr.to_csv(outPath)
    else:
        logger.info(outCorr)
    return outCorr


# ---------------------------------------------------------------------------
# Analysis
# ---------------------------------------------------------------------------
# Autocorrelation - every metric for every paper, up to the default lag (5)
combinedAutocorrelation(paper_list, metrics_list, save=True)
# ADF Test - every metric for every paper
combinedADFTest(paper_list, metrics_list, save=True)
# Granger causality - both directions (text metrics on proxies and proxies
# on text metrics), run here for the aggregated "COMB" data only
GC_tables(["COMB"], metrics_list, save=True)
# Export test of text metrics table
writeMetTestTableToFile()
# Create data for swathe plots - individual papers plus the aggregate
swatheData(paper_list + ["COMB"], save=True)
# Create correlation matrix data - individual papers plus the aggregate
corrMatData(paper_list + ["COMB"], metrics_list, save=True)
