#!/usr/bin/env python3
# Author(s):  Nurzhan Sapargali <sapargalin95@gmail.com>
#             Michael E. Rose <Michael.Ernst.Rose@gmail.com>
"""Outputs network statistics for most linked economists."""

from glob import glob
from os.path import basename, splitext

import networkx as nx
import pandas as pd
import seaborn as sns
import matplotlib as mpl
import matplotlib.pyplot as plt
from pybliometrics.scopus import AuthorRetrieval

# Folders scanned for "*gexf" network files (one per source), and the root
# folder below which tables and figures are written.
ECONLIT_FOLDER = "./100_EconLit_networks/"
SCOPUS_FOLDER = "./200_Scopus_networks/"
OUTPUT_FOLDER = "./990_output/"

# Degree cutoff per decade: main() pools every node whose degree is at or
# above the cutoff into the cutoff value before averaging clustering.
_grouping = {
    "1970s": 14,
    "1980s": 17,
    "1990s": 22,
    "2000s": 37,
    "2010s": 65,
}


def get_author_names(auth_ids):
    """Return "surname, given name" for each Scopus author ID.

    One AuthorRetrieval lookup is made per entry of `auth_ids`, in order,
    and the resulting list has the same length and order as the input.
    """
    def _display_name(profile):
        # str.join keeps the original failure mode if a name part is missing
        return ", ".join((profile.surname, profile.given_name))

    return [_display_name(AuthorRetrieval(auth_id)) for auth_id in auth_ids]


def make_figure1(df, fname, size=(5, 5)):
    """Create and write a log-log scatterplot of degree versus clustering.

    `df` must provide the columns "degree", "clustering" and "year"; each
    decade in "year" gets its own marker and color.  A dash-dotted black
    line from (1, 1) to (100, 0.01) is drawn on top of the points.  The
    figure of size `size` (inches) is saved to `fname` and then closed.
    """
    # Per-decade marker and color lookup tables
    decades = ("1970s", "1980s", "1990s", "2000s", "2010s")
    markers = dict(zip(decades, (".", "^", "X", "o", "s")))
    colors = dict(zip(decades, ("blue", "orange", "green", "purple", "red")))
    x_ticks = [1, 10, 100]
    y_ticks = [0.01, 0.02, 0.03, 0.04, 0.06, 0.1, 0.2, 0.3, 0.4, 0.6, 1]

    # Scatter plus the straight line
    fig, ax = plt.subplots(figsize=size)
    sns.scatterplot(data=df, x="degree", y="clustering", hue="year",
                    palette=colors, style="year", markers=markers, ax=ax)
    ax.plot([1, 100], [1, 0.01], linestyle="-.", color="black", linewidth=1)

    # Rebuild the legend from the existing entries
    handles, labels = ax.get_legend_handles_labels()
    ax.legend(handles=list(handles), labels=list(labels))

    # Log axes spanning the first to the last tick, labelled as plain numbers
    ax.set(xscale="log", xlim=(x_ticks[0], x_ticks[-1]),
           yscale="log", ylim=(y_ticks[0], y_ticks[-1]))
    for axis in (ax.xaxis, ax.yaxis):
        axis.set_major_formatter(mpl.ticker.ScalarFormatter())
    ax.set(xticks=x_ticks, yticks=y_ticks)

    # Write to disk and release the figure
    fig.savefig(fname, bbox_inches="tight")
    plt.close(fig)


def make_table2(fname, G, degree, clustering):
    """Write a LaTeX table of the five most connected nodes plus averages.

    Nodes of `G` are ranked by degree and, within equal degree, by their
    number of second-order neighbors (both descending).  The table holds
    the top five nodes followed by two rows with averages over the top 100
    and over all nodes.  The identifiers of the top five are printed; if
    the first identifier consists of digits only, all five are resolved to
    names via get_author_names().

    `degree` and `clustering` are node-keyed mappings; paper counts are
    read from the node attributes "Number of publications" and
    "Number of co-authored publications".  The result goes to `fname`.
    """
    # One row per node
    second_order = {node: num_sec_neigh(node, G) for node in G.nodes}
    columns = {
        'Papers': nx.get_node_attributes(G, "Number of publications"),
        'coauthored': nx.get_node_attributes(
            G, "Number of co-authored publications"),
        'Degree': degree,
        'Distance 2': second_order,
        'Clustering Coefficient': clustering,
    }
    ranked = pd.DataFrame(columns)
    ranked = ranked.sort_values(by=['Degree', 'Distance 2'], ascending=False)
    ranked = ranked.reset_index()
    ranked.index += 1  # ranks start at 1, node ids now live in "index"

    # Announce the five best-connected nodes
    table = ranked.head(5).copy()
    leaders = table['index'].values
    if leaders[0].isdigit():
        leaders = get_author_names(leaders)
    print(f">>> 5 most connected Economists: {'; '.join(leaders)}")

    # Body of the table: share of coauthored papers replaces the raw count
    table = table.drop(columns="index")
    table["coauthored"] = table["coauthored"] / table["Papers"]
    table = table.rename(columns={"coauthored": "% Coauthored"})

    # Append column means for the top 100 and for everybody
    numeric = ranked.drop(columns="index")
    first100 = numeric.head(100)
    table.loc['Average top 100'] = list(first100.mean())
    table.loc['Average all'] = list(numeric.mean())

    # The coauthored share is pooled (sum over sum) rather than a mean of
    # counts, and the overall clustering is the graph's transitivity
    share_100 = first100["coauthored"].sum() / first100["Papers"].sum()
    share_all = numeric["coauthored"].sum() / numeric["Papers"].sum()
    table.loc['Average top 100', '% Coauthored'] = share_100
    table.loc['Average all', '% Coauthored'] = share_all
    table.loc['Average all', 'Clustering Coefficient'] = nx.transitivity(G)

    # Express the share in percent and write the file
    table['% Coauthored'] = (table['% Coauthored'] * 100).round(1)
    table.to_latex(fname, float_format=lambda value: f'{value:,.3f}')


def num_sec_neigh(node, G):
    """Count the nodes of `G` at shortest-path distance exactly 2 from `node`."""
    # The search is capped at depth 2, so distances are 0 (node itself), 1 or 2
    distances = nx.single_source_shortest_path_length(G, node, cutoff=2)
    return list(distances.values()).count(2)


def main():
    """Compute network statistics for both sources and write all outputs.

    For each source folder, every "*gexf" file is read as one network whose
    decade label is the file's base name plus "s".  Degree and clustering
    of all nodes are collected, the 1990s network additionally yields
    table 2, and two scatterplots (1970s-1990s and 2000s-2010s) of mean
    clustering by degree are written.  A source without any network file
    is reported and skipped.
    """
    print(">>> Now working on:")
    for label, folder in (("Econ", ECONLIT_FOLDER), ("Scopus", SCOPUS_FOLDER)):
        print("...", label, "...")
        frames = []
        # Sorted so that processing and log order do not depend on the OS
        for f in sorted(glob(folder + "*gexf")):
            G = nx.read_gexf(f)
            year = splitext(basename(f))[0] + "s"
            print("...", year)

            # Compute centralities
            degs = dict(G.degree)
            clust = dict(nx.clustering(G))
            frames.append(pd.DataFrame({'degree': degs,
                                        'clustering': clust,
                                        'year': year}))

            # Create table 2 (most connected nodes)
            if year == "1990s":
                fname = f"{OUTPUT_FOLDER}/Tables/{label}_2.tex"
                make_table2(fname, G, degs, clust)

        # Without any network there is nothing to plot for this source
        if not frames:
            print("... no networks found in", folder)
            continue

        # Concatenate once instead of growing a frame inside the loop
        df = pd.concat(frames, sort=False)

        # Prepare for scatter of degree versus clustering; the copy makes
        # the filtered frame independent so it can be modified in place
        df = df[df["degree"] >= 2].copy()
        for year, deg_cutoff in _grouping.items():
            # Pool all nodes at or above the decade's cutoff into one bin
            mask = (df["year"] == year) & (df["degree"] >= deg_cutoff)
            df.loc[mask, "degree"] = deg_cutoff
        clust = df.groupby(["year", "degree"])["clustering"].mean().reset_index()
        # Plot
        mask_wide = clust["year"].isin(("2000s", "2010s"))
        fname = f"{OUTPUT_FOLDER}/Figures/{label}_1.pdf"
        make_figure1(clust[~mask_wide], fname)
        fname = f"{OUTPUT_FOLDER}/Figures/{label}_1_long.pdf"
        make_figure1(clust[mask_wide], fname)

if __name__ == '__main__':
    main()
