#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""

This script builds term-frequency matrices from the cleaned newspaper text, using the
reference dictionaries and an n-gram range of our choice as the vocabulary, and saves
the results in the IntermediateDataOutputs directory.
"""
import pandas as pd
import os
import json
from sklearn.feature_extraction.text import CountVectorizer
import configparser
import glob
import src.txtutils as txtutil
import src.dictionaryclass as rfdictcls
import importlib
from loguru import logger

# Reload txtutil so that edits to src/txtutils.py are picked up when this
# module is re-imported in a long-lived session (e.g. a notebook).
# NOTE(review): looks like a development convenience — consider removing.
importlib.reload(txtutil)
# ---------------------------------------------------------------------------
# Settings
# ---------------------------------------------------------------------------
# Module-level state shared by the functions below; all of it runs at import.
# "config.ini" is resolved relative to the current working directory, and
# ConfigParser.read silently skips files it cannot open, so running from the
# wrong directory only surfaces later as a KeyError on config[...] lookups.
config = configparser.ConfigParser()
config.read("config.ini")
# Stop words used as the default for every CountVectorizer in __fn_tdm_df.
my_stop_words = txtutil.getStopwords()
# Reference dictionaries; tfmatrix builds its vocabulary from their keys.
refDictsInstance = rfdictcls.RefDictClass()


# ---------------------------------------------------------------------------
# Functions
# ---------------------------------------------------------------------------
def __fn_tdm_df(docs, **kwargs):
    """Build a document-term count matrix for ``docs``.

    Thin wrapper around ``CountVectorizer`` so callers can pass any
    vectorizer options (``vocabulary``, ``ngram_range``, ...) via ``**kwargs``.
    The module-level ``my_stop_words`` is the default ``stop_words``; a caller
    may override it by passing ``stop_words=...`` explicitly.

    Args:
        docs: documents (strings) exposing an ``.index`` — in this module a
            pandas Series. Its index becomes the row index of the result.
        **kwargs: extra keyword arguments forwarded to ``CountVectorizer``.

    Returns:
        pandas.DataFrame of raw counts: one row per document, one column per
        feature in ``get_feature_names_out`` order.
    """
    # Merge rather than pass stop_words twice: an explicit stop_words in
    # kwargs now overrides the default instead of raising
    # "got multiple values for keyword argument 'stop_words'".
    vectorizer_params = {"stop_words": my_stop_words, **kwargs}
    vectorizer = CountVectorizer(**vectorizer_params)
    counts = vectorizer.fit_transform(docs)
    # NOTE: toarray() densifies the sparse result (n_docs x n_features cells).
    return pd.DataFrame(
        counts.toarray(),
        columns=vectorizer.get_feature_names_out(),
        index=docs.index,
    )


def tfmatrix(dfColIn, newspaper, save=False, *, ngram_range=(1, 3)):
    """Create the term-frequency matrix for one newspaper, optionally saving it.

    The vocabulary is the union of the keys of the three reference
    dictionaries (``unique_dict_neg``, ``unique_dict_pos`` and
    ``union_dictionary``); only those terms are counted.

    Args:
        dfColIn: pandas Series of cleaned document texts; its index becomes
            the row index of the matrix.
        newspaper: newspaper identifier, used as the output filename prefix.
        save: if True, write the matrix as CSV to
            ``config["data"]["intermed"]/<newspaper><config["fEnds"]["tf_pa"]>``.
        ngram_range: ``(min_n, max_n)`` forwarded to CountVectorizer.
            Dictionary entries with more than ``max_n`` tokens never match.

    Returns:
        pandas.DataFrame of counts, one row per document and one column per
        vocabulary term (columns in sorted order).
    """
    # Sort the de-duplicated vocabulary. Iteration order of a set of str is
    # not stable across interpreter runs (hash randomisation), and
    # CountVectorizer keeps the order of a given vocabulary list, so the
    # previous list(set(...)) made the column order change between runs.
    # NOTE(review): CountVectorizer lowercases documents by default but not
    # the vocabulary; assumes the dictionary keys are already lowercase.
    vocab = sorted(
        set().union(
            refDictsInstance.unique_dict_neg.keys(),
            refDictsInstance.unique_dict_pos.keys(),
            refDictsInstance.union_dictionary.keys(),
        )
    )
    logger.info("Vocab assembled; running tf")
    tf_matrix = __fn_tdm_df(dfColIn, vocabulary=vocab, ngram_range=ngram_range)
    logger.info("tf completed \n ----------------------------------")
    if save:
        # Resolve the path (and its config lookups) only when writing.
        o_path = os.path.join(
            config["data"]["intermed"], newspaper + config["fEnds"]["tf_pa"]
        )
        tf_matrix.to_csv(o_path, date_format="%d %m %Y", encoding="utf8")
    return tf_matrix


def tfmatrix_agg(tf_matrix, newspaper):
    """Average a document-level TF matrix per calendar month and save it.

    The result is written as CSV to
    ``config["data"]["intermed"]/<newspaper><config["fEnds"]["tf_m"]>``.

    Args:
        tf_matrix: DataFrame as returned by ``tfmatrix``. ``pd.Grouper`` with
            a frequency groups on the index, so the index must be a
            DatetimeIndex (presumably publication dates — confirm against
            ``txtutil.getCleanTextDf``).
        newspaper: newspaper identifier, used as the output filename prefix.

    Returns:
        The monthly-mean DataFrame, labelled by month-end date (the same
        frame that was written to disk).
    """
    # MonthEnd() is the offset the "M" alias resolves to. Passing the object
    # avoids the alias deprecation in pandas >= 2.2 ("M" -> "ME") while still
    # working on older pandas versions.
    monthly = tf_matrix.groupby(pd.Grouper(freq=pd.offsets.MonthEnd())).mean()
    o_path = os.path.join(
        config["data"]["intermed"], newspaper + config["fEnds"]["tf_m"]
    )
    monthly.to_csv(o_path, date_format="%d %m %Y", encoding="utf8")
    return monthly


def tf_matrices_of_newspaper_full_text(nrows=None):
    """Build and save the monthly TF matrix for every configured newspaper.

    Newspapers come from ``config["papers"]["paper_list"]`` (a JSON list);
    "TEST" is prepended when it is not already in that list. A newspaper is
    skipped when its monthly output file (the one ``tfmatrix_agg`` writes)
    already exists.

    Args:
        nrows: number of rows to read per newspaper, forwarded to
            ``txtutil.getCleanTextDf``; None (the default) reads all rows.
    """
    paper_list = json.loads(config.get("papers", "paper_list"))
    if "TEST" not in paper_list:
        paper_list = ["TEST"] + paper_list
    # Same directory/suffix that tfmatrix_agg uses for its output.
    intermed_dir = config["data"]["intermed"]
    monthly_suffix = config["fEnds"]["tf_m"]
    for newspaper in paper_list:
        logger.info(f"Running {newspaper}")
        # Check the exact configured output file rather than a hard-coded
        # "data/intermed/<paper>tf*.csv" glob that could disagree with
        # config.ini (and silently reprocess or wrongly skip a paper).
        monthly_path = os.path.join(intermed_dir, newspaper + monthly_suffix)
        if os.path.exists(monthly_path):
            logger.info(f"{newspaper=} has already been processed")
            continue
        logger.info(f"processing {newspaper}")
        df = txtutil.getCleanTextDf(newspaper, nrows)
        tf_matrix = tfmatrix(df["cleanText"], newspaper, save=False)
        tfmatrix_agg(tf_matrix, newspaper)
