#!/usr/bin/env python3
# Authors:   Nurzhan Sapargali <sapargalin95@gmail.com>
#            Michael E. Rose <michael.ernst.rose@gmail.com>
"""Creates co-author networks for five separate decades using Scopus data
and optionally aligns them with the coverage of EconLit.
"""

from itertools import combinations, product
from collections import defaultdict, namedtuple

import matplotlib.pyplot as plt
import networkx as nx
import pandas as pd
from matplotlib_venn import venn2
from num2words import num2words
from pybliometrics.scopus import AbstractRetrieval, ScopusSearch
from tqdm import tqdm

from _100_create_EconLit_networks import write_stats

# Folders the script reads its inputs from
MAPPING_FOLDER = "./020_mapping/"
COVERAGE_FOLDER = "./050_coverage/"
# Folders the script writes networks and figures to
TARGET_FOLDER = "./200_Scopus_networks/"
OUTPUT_FOLDER = "./990_output/"

# Whether to align Scopus network with EconLit coverage
ECONLIT_ONLY = True
# One range of years per decade, from the 1970s through the 2010s
DECADES = [range(start, start + 10) for start in range(1970, 2020, 10)]
# Publication subtype codes that are kept when building the networks
TYPES = {"sh", "no", "re", "cp", "ar"}


def make_coverage_venn(df, fname):
    """Save a 2-set Venn diagram showing the Scopus/EconLit coverage overlap.

    Rows of `df` are identified by their index label.  A row belongs to the
    Scopus set when its "Scopus" column equals 1, and to the EconLit set when
    its "Econlit" column is not missing.  The figure is written to `fname`.
    """
    # Region id -> fill color ('10' Scopus only, '01' EconLit only, '11' both)
    colors = {'01': "#a99c73", '11': "#655d45", '10': '#9c73a9'}
    # Split data
    scopus = set(df[df["Scopus"] == 1].index)
    econlit = set(df[~df["Econlit"].isna()].index)
    # Create plot
    plt.figure(figsize=(4, 4))
    try:
        v = venn2([scopus, econlit], ('Scopus', 'EconLit'))
        for patch_id, color in colors.items():
            patch = v.get_patch_by_id(patch_id)
            # Empty regions have no patch; skip them instead of crashing
            if patch is None:
                continue
            patch.set_color(color)
            patch.set_alpha(1)
        # Save figure
        plt.savefig(fname, bbox_inches="tight")
    finally:
        # Release the figure even if drawing or saving failed
        plt.close()


def make_pnas_dict():
    """Create dictionary mapping PNAS articles to their EIDs.

    The result has one key per column of EconLit_PNAS.csv; each value is the
    list of "eid" entries from PNAS_Scopus.csv for the non-missing articles
    listed in that column.
    """
    articles = pd.read_csv(COVERAGE_FOLDER + "EconLit_PNAS.csv")
    by_column = {}
    for col in articles.columns:
        by_column[col] = articles[col].dropna().tolist()
    # The first column of the mapping file identifies the article
    lookup = pd.read_csv(MAPPING_FOLDER + "PNAS_Scopus.csv", index_col=0)
    eid_of = lookup["eid"].to_dict()
    result = {}
    for col, listed in by_column.items():
        result[col] = [eid_of[article] for article in listed]
    return result


def robust_query(q, refresh=200, fields=("eid", "coverDate")):
    """Download a ScopusSearch result set, retrying if the first attempt fails.

    The query `q` is first run with the given `refresh` value and with
    `fields` as integrity fields.  If that raises AttributeError or
    TypeError, the query is repeated with `refresh=True`; if that raises
    AttributeError, it is run once more without any extra arguments.

    Returns the list of results, or an empty list when there are none.
    """
    # Hand ScopusSearch a fresh list each time so that no default object is
    # shared between calls; an explicit None is passed through unchanged
    integrity = list(fields) if fields is not None else None
    try:
        res = ScopusSearch(q, integrity_fields=integrity,
                           refresh=refresh).results
    except (AttributeError, TypeError):
        try:
            res = ScopusSearch(q, integrity_fields=integrity,
                               refresh=True).results
        except AttributeError:
            res = ScopusSearch(q).results
    return res or []


def main():
    """Build one Scopus co-author network per decade and plot the coverage.

    For every decade in DECADES, the publications of all relevant sources are
    downloaded, the PNAS articles of that decade are added, and a graph with
    one node per author and one edge per co-author pair is written as a
    .gexf file to TARGET_FOLDER.  Summary counts are handed to `write_stats`.
    If ECONLIT_ONLY is set, two Venn diagrams comparing the journal-years
    found in Scopus with those of EconLit are saved below OUTPUT_FOLDER.
    """
    # Read Scopus Source IDs
    scopus = pd.read_csv(MAPPING_FOLDER + "Scopus.csv", dtype="object")
    # Missing IDs are read as NaN, which is truthy; leave those journals out
    # so that the lookup below returns None and the journal is skipped
    id_map = (scopus.dropna(subset=["source_id"])
                    .set_index("Journal")["source_id"].to_dict())
    # Former journal name -> name it is listed under in "Journal"
    name_map = scopus.dropna().set_index("former")["Journal"].to_dict()

    # Get relevant sources
    if ECONLIT_ONLY:
        econlit = pd.read_csv(COVERAGE_FOLDER + "econlit_actual.csv")
        econlit = econlit.pivot_table(index="journal", columns="year",
                                      values="pub_count")
        sources = econlit.index.to_list()
        df = (econlit.reset_index()
                     .melt(id_vars="journal", value_name="coverage")
                     .dropna(subset=["coverage"]))
        df["year"] = df["year"].astype("uint")
        covered = df[["journal", "year"]].dropna()
        # Set of (journal, year) tuples for constant-time membership tests
        coverage = set(zip(covered["journal"].tolist(),
                           covered["year"].astype(int).tolist()))
    else:
        sources = scopus["Journal"].unique()
        name_map = {}
        coverage = set()

    # Read EIDs of PNAS articles
    pnas = make_pnas_dict()
    PNAS_item = namedtuple("PNAS", "subtype author_ids author_count")

    # Build networks by decade
    stats = {}
    scopus_coverage = set()
    print(">>> Now working on:")
    for decade in DECADES:
        year = str(decade[0])
        dec_name = num2words(year[-2:])
        print(f"... {year}s:")
        # Get publications
        pubs = []
        n_volumes = 0
        combs = list(product(decade, sources))
        for y, source1 in tqdm(combs):
            # Get publication list
            source_id = id_map.get(source1)
            if not source_id:
                continue
            query = f"SOURCE-ID({source_id}) AND PUBYEAR IS {y}"
            res = robust_query(query, refresh=150)
            # Follow the chain of renames and add the successors' documents
            source2 = source1
            visited = {source1}
            while source2 in name_map:
                source2 = name_map[source2]
                if source2 in visited:
                    # Cyclic rename entries would otherwise loop forever
                    break
                visited.add(source2)
                source_id = id_map.get(source2)
                if not source_id:
                    # No ID to query for this name; keep following the chain
                    continue
                query = f"SOURCE-ID({source_id}) AND PUBYEAR IS {y}"
                res.extend(robust_query(query, refresh=150))
            if not res:
                continue
            # Store information
            # NOTE(review): the volume is counted before the coverage filter
            # below, so volumes whose publications are skipped are included
            # in n_volumes - confirm this is intended
            n_volumes += 1
            scopus_coverage.add((source1, y))
            if ECONLIT_ONLY and (source1, y) not in coverage:
                continue
            pubs.extend(res)
        # Add PNAS articles
        for eid in pnas[year]:
            ab = AbstractRetrieval(eid, view="FULL")
            # Documents without author information yield an empty ID string
            # and are removed by the filter below
            authors = ab.authors or []
            new = PNAS_item("ar", ";".join([au.auid for au in authors]),
                            len(authors))
            pubs.append(new)
        # Build network
        pubs = [p for p in pubs if p.author_ids and p.subtype in TYPES]
        print(f"  - {len(pubs):,} publications from {n_volumes:,} volumes")
        stats[f"Scopus_N_of_pubs_{dec_name}s"] = len(pubs)
        stats[f"Scopus_N_of_volumes_{dec_name}s"] = n_volumes
        # Compute meta information
        pub_count = defaultdict(int)
        pubco_count = defaultdict(int)
        auth_groups = []
        for p in pubs:
            authors = p.author_ids.split(';')
            auth_groups.append(authors)
            for auth in authors:
                pub_count[auth] += 1
                if int(p.author_count) > 1:
                    pubco_count[auth] += 1
        # Generate network edges
        print(f"  - {len(pub_count):,} distinct authors")
        stats[f"Scopus_N_of_authors_{dec_name}s"] = len(pub_count)
        combs = [combinations(i, 2) for i in auth_groups]
        edges = [i for j in combs for i in j]
        # Populate network
        G = nx.Graph(name=year)
        G.add_nodes_from(pub_count.keys())
        G.add_edges_from(edges)
        nx.set_node_attributes(G, pub_count, "Number of publications")
        nx.set_node_attributes(G, pubco_count, "Number of co-authored publications")
        # Write out (TARGET_FOLDER already ends with a slash)
        fname = f"{TARGET_FOLDER}{year}.gexf"
        nx.write_gexf(G, fname)

    # Statistics
    write_stats(stats)

    # Coverage graph
    if ECONLIT_ONLY:
        new = pd.DataFrame.from_records(list(scopus_coverage),
                                        columns=["journal", "year"])
        new["Scopus"] = 1
        df = (df.merge(new, "outer", on=["journal", "year"], indicator=True)
                .sort_values(["journal", "year"])
                .rename(columns={"coverage": "Econlit"}))
        fname = OUTPUT_FOLDER + "Figures/coverage_old.pdf"
        make_coverage_venn(df[df["year"] < 2000], fname)
        fname = OUTPUT_FOLDER + "Figures/coverage_new.pdf"
        make_coverage_venn(df[df["year"] > 1999], fname)


# Run the pipeline only when the file is executed as a script, not on import
if __name__ == '__main__':
    main()
