#!/usr/bin/env python

"""
Script report_fcg_chunk
=======================
This script is used for parsing an FCG chunk. It creates nine outputs:
    - fcg_counts: all the counts regarding a chunk (number of mols, etc.)
    - fcg_nfcgpermol
    - fcg_topfrags
    - fcg_topfrags_u
    - fcg_nhits
    - fcg_nhits_u
    - fcg_fcc
    - fcg_fc
    - fcg_fragratio

These outputs can then be concatenated with other chunks to create the report.

"""

# standard
import warnings
import logging
import argparse
import sys
from shutil import rmtree
from datetime import datetime
import re
from pathlib import Path
from collections import OrderedDict
# data handling
import numpy as np
import json
import pandas as pd
from pandas import DataFrame
import networkx as nx
# data visualization
from matplotlib import pyplot as plt
import seaborn as sns
import matplotlib.ticker as mticker
# from pylab import savefig
from adjustText import adjust_text
# chemoinformatics
import rdkit
from rdkit import Chem
from rdkit.Chem import Mol
# docs
from typing import List
from typing import Tuple
from rdkit import RDLogger
# custom libraries
import npfc
from npfc import utils
from npfc import load
from npfc import save
from npfc import report
from npfc import draw
from npfc import fragment_combination
from multiprocessing import Pool


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FUNCTIONS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


def _summarize_fcc_subset(df_fc: DataFrame, category: str, count_label: str, num_mols_tot: int):
    """Isolate the Fragment Combinations of one category and count them per molecule.

    :param df_fc: a dataframe with fragment combinations
    :param category: the value of the 'fcc' column to isolate (i.e. 'ffs' or 'ffo')
    :param count_label: the name of the column holding the number of hits per molecule
    :param num_mols_tot: the total number of molecules found in df_fc
    :return: a tuple with the isolated rows, their number, the number of molecules with at least 1 hit, the number of molecules with 0 hits and a DataFrame with the number of molecules (Count) per number of hits (count_label)
    """
    df_subset = df_fc[df_fc['fcc'] == category]
    num_fc = len(df_subset)

    # no hit at all: every molecule ends up in the '0' bin
    if num_fc == 0:
        df_count = DataFrame([[0, num_mols_tot]], columns=[count_label, 'Count'])
        return df_subset, num_fc, 0, num_mols_tot, df_count

    num_mols_with = len(df_subset.groupby('idm'))
    has_no_hit = ~df_fc['idm'].isin(df_subset['idm'])
    num_mols_without = len(df_fc[has_no_hit].groupby('idm'))

    # number of hits per molecule, then number of molecules per number of hits
    hits_per_mol = df_subset[['idm', 'idf1', 'idf2']].groupby('idm').count()
    hits_per_mol = hits_per_mol.rename({'idf1': count_label}, axis=1)
    mols_per_hits = hits_per_mol.groupby(count_label).count().reset_index()
    mols_per_hits = mols_per_hits.rename({'idf2': 'Count'}, axis=1)

    # molecules without any hit are prepended as the '0' bin
    zero_bin = DataFrame({count_label: [0], 'Count': [num_mols_without]})
    df_count = pd.concat([zero_bin, mols_per_hits]).reset_index(drop=True)

    return df_subset, num_fc, num_mols_with, num_mols_without, df_count


def get_dfs_fcc_from_df_fc(df_fc: DataFrame):
    """Split Fragment Combinations (edges) into false positives (ffs, ffo) and true positives (the rest) and count them.

    The returned dictionary contains:
        - df_fc: a DataFrame with counts of true positive Fragment Combinations
        - df_fcc: a DataFrame with counts of Fragment Combination Categories
        - df_ffs: a DataFrame with the number of molecules per number of ffs
        - df_ffo: a DataFrame with the number of molecules per number of ffo
        - num_fc_ffs: the number of ffs Fragment Combinations
        - num_fc_ffo: the number of ffo Fragment Combinations
        - num_fc_tp: the number of true positive Fragment Combinations, i.e. that are not ffs or ffo
        - num_fc_tot: the total number of Fragment Combinations
        - num_mols_tot: the total number of molecules
        - num_mols_tp: the number of molecules with only true positive Fragment Combinations
        - num_mols_ffs: the number of molecules with at least 1 ffs Fragment Combination
        - num_mols_noffs: the number of molecules with 0 ffs Fragment Combinations
        - num_mols_ffo: the number of molecules with at least 1 ffo Fragment Combination
        - num_mols_noffo: the number of molecules with 0 ffo Fragment Combinations

    This function is used within iterations over chunks in other functions, so counts have to be summed up.

    :param df_fc: a dataframe with fragment combinations
    :return: a dictionary with counts of Fragment Combinations.
    """
    # init
    categories = fragment_combination.get_fragment_combination_categories()

    # global counts
    num_fc_tot = len(df_fc.index)
    num_mols_tot = len(df_fc.groupby('idm'))

    # false positives: substructures (ffs) and overlaps (ffo)
    df_fc_ffs, num_fc_ffs, num_mols_ffs, num_mols_noffs, df_fc_ffs_count = _summarize_fcc_subset(df_fc, 'ffs', 'NumSubstructures', num_mols_tot)
    df_fc_ffo, num_fc_ffo, num_mols_ffo, num_mols_noffo, df_fc_ffo_count = _summarize_fcc_subset(df_fc, 'ffo', 'NumOverlaps', num_mols_tot)

    # true positives: whatever is neither ffs nor ffo
    is_false_positive = df_fc['fcc'].isin(['ffs', 'ffo'])
    df_tp = df_fc[~is_false_positive]
    without_ffs = ~df_tp['idm'].isin(df_fc_ffs['idm'])
    without_ffo = ~df_tp['idm'].isin(df_fc_ffo['idm'])
    num_mols_tp = len(df_tp[(without_ffs) & (without_ffo)].groupby('idm'))
    num_fc_tp = len(df_tp)

    # fcc: observed counts are completed with 0 for categories that were not observed,
    # then categories are put back in their reference order
    df_fcc_default = pd.DataFrame({'fcc': categories, 'Count': [0] * len(categories)})
    df_fcc_observed = df_tp[['fcc', 'idm']].groupby('fcc').count()
    df_fcc_observed = df_fcc_observed.rename({'idm': 'Count'}, axis=1).reset_index()
    df_fcc_count = pd.concat([df_fcc_observed, df_fcc_default]).groupby('fcc').sum().T
    df_fcc_count = df_fcc_count[categories]
    df_fcc_count = df_fcc_count.T.reset_index().rename({'index': 'fcc'}, axis=1)

    # top fc: one row per distinct combination, labelled as idf1[fcc]idf2
    fc_keys = ['idf1', 'idf2', 'fcc', 'smiles_frag_1', 'smiles_frag_2']
    df_fc_count = df_tp[['idf1', 'idf2', 'fcc', 'idm', 'smiles_frag_1', 'smiles_frag_2']].groupby(fc_keys).count()
    df_fc_count = df_fc_count.rename({'idm': 'Count'}, axis=1).reset_index()
    df_fc_count['fc'] = df_fc_count['idf1'] + '[' + df_fc_count['fcc'] + ']' + df_fc_count['idf2']
    df_fc_count = df_fc_count.drop(['idf1', 'idf2', 'fcc'], axis=1)

    return {'df_fc': df_fc_count,
            'df_fcc': df_fcc_count,
            'df_ffs': df_fc_ffs_count,
            'df_ffo': df_fc_ffo_count,
            'num_fc_ffs': num_fc_ffs,
            'num_fc_ffo': num_fc_ffo,
            'num_fc_tp': num_fc_tp,
            'num_fc_tot': num_fc_tot,
            'num_mols_tot': num_mols_tot,
            'num_mols_tp': num_mols_tp,
            'num_mols_ffs': num_mols_ffs,
            'num_mols_noffs': num_mols_noffs,
            'num_mols_ffo': num_mols_ffo,
            'num_mols_noffo': num_mols_noffo,
            }


def compute_mol_coverage_without_side_chains(mol, l_aidxs):
    """Compute the coverage of a molecule by fragments when side chains are ignored.

    Atoms of the molecule are split into three groups:
        - ring atoms: atoms of the (fused) rings
        - linker atoms: non-ring atoms found on the shortest path between the first atoms of two (fused) rings
        - side chain atoms: all the other atoms

    Only ring and linker atoms are considered for computing the coverage.

    :param mol: the molecule (presumably a RDKit Mol, it has to be accepted by Chem.GetShortestPath)
    :param l_aidxs: the atom indices of the molecule that are part of fragments, duplicates are allowed
    :return: a list with the number of ring and linker atoms, the number of those atoms found in fragments and the ratio of both (NaN when the molecule has no ring or linker atom)
    """
    rings = utils.fuse_rings(mol.GetRingInfo().AtomRings())
    # sets are used for membership tests, so each lookup does not scan (and rebuild) a list
    ring_atoms = {aidx for ring in rings for aidx in ring}

    linker_atoms = set()
    for i, ring1 in enumerate(rings):
        for ring2 in rings[i + 1:]:
            # shortest path between the two rings, without any ring atom (including those of the current rings)
            shortest_path = Chem.GetShortestPath(mol, ring1[0], ring2[0])
            linker_atoms.update(aidx for aidx in shortest_path if aidx not in ring_atoms)

    # define side chains as atoms not part of linkers and rings
    all_atoms = range(len(mol.GetAtoms()))
    side_chain_atoms = {aidx for aidx in all_atoms if aidx not in ring_atoms and aidx not in linker_atoms}

    # remove side chains from the equation
    num_atoms = len(all_atoms) - len(side_chain_atoms)
    num_frag_atoms = len(set(l_aidxs) - side_chain_atoms)

    # a molecule without any ring or linker atom has no defined coverage: report NaN instead of failing on a division by zero
    ratio = num_frag_atoms / num_atoms if num_atoms else float('nan')

    return [num_atoms, num_frag_atoms, ratio]


def parse_chunk_fcg(c):
    """Parse a FCG chunk and compute the counts and tables used for the report.

    When the chunk does not contain any fragment combination graph, counts are set to 0
    and empty DataFrames are returned. These empty DataFrames have the same columns as
    the ones computed for a regular chunk, so outputs of all chunks share a single layout.

    The returned dictionary contains:
        - num_tot_fcg_graph: the number of rows (fragment combination graphs) in the chunk
        - num_tot_mol_graph: the number of distinct idm in the chunk
        - n_fcg_nhits_tot: the sum of the counts of df_fcg_topfrags
        - n_fcg_nhits_u_tot: the sum of the counts of df_fcg_nhits_u
        - df_fcg_nfcgpermol: the number of molecules (Count) per number of graphs (NumFCG)
        - df_fcg_nhits: the number of molecules (Count, Perc_Mols) per number of fragment hits (NumFrags)
        - df_fcg_topfrags: the number of hits (Count) per fragment (idf)
        - df_fcg_fragratio: the fragment coverage of each molecule, with and without side chains
        - df_fcg_nhits_u: the number of molecules (Count) per number of distinct fragments (NumFrags)
        - df_fcg_topfrags_u: the number of molecules (Count) per fragment (idf)
        - df_fcg_fcc: the counts of Fragment Combination Categories
        - df_fcg_fc: the counts of Fragment Combinations

    :param c: the FCG chunk to parse, it is passed as is to load.file
    :return: a dictionary with counts and DataFrames
    """
    # retrieve data
    d = {}
    df_fcg = load.file(c, decode=['_fcg', '_d_mol_frags', '_frags']).sort_values(["idm", "nfrags"], ascending=True)  # df_fcg already sorted for the best examples per case
    num_tot_fcg_graph = len(df_fcg.index)

    if num_tot_fcg_graph == 0:
        # nothing to parse: columns have to be identical to what is computed below for a regular chunk
        d['num_tot_fcg_graph'] = 0
        d['num_tot_mol_graph'] = 0
        d['n_fcg_nhits_tot'] = 0
        d['n_fcg_nhits_u_tot'] = 0
        d['df_fcg_nfcgpermol'] = pd.DataFrame([], columns=['NumFCG', 'Count'])
        d['df_fcg_nhits'] = pd.DataFrame([], columns=['NumFrags', 'Count', 'Perc_Mols'])
        d['df_fcg_topfrags'] = pd.DataFrame([], columns=['idf', 'Count'])
        d['df_fcg_fragratio'] = pd.DataFrame([], columns=['idm', 'hac_mol', 'hac_frags', 'fragratio', 'hac_mol_wo_side_chain', 'hac_frag_wo_side_chain', 'fragratio_wo_side_chain'])
        d['df_fcg_nhits_u'] = pd.DataFrame([], columns=['NumFrags', 'Count'])
        d['df_fcg_topfrags_u'] = pd.DataFrame([], columns=['idf', 'Count'])
        d['df_fcg_fcc'] = pd.DataFrame([], columns=['fcc', 'Count'])
        d['df_fcg_fc'] = pd.DataFrame([], columns=['smiles_frag_1', 'smiles_frag_2', 'Count', 'fc'])

        return d

    df_fcg = df_fcg.rename({'_frags': 'frags', '_frags_u': 'frags_u'}, axis=1)
    df_fcg['frags'] = df_fcg['frags'].map(list)
    df_fcg['frags_u'] = df_fcg['frags'].map(lambda x: list(set(x)))
    groups = df_fcg[['idm', 'idfcg', 'nfrags']].groupby('idm')
    num_tot_mol_graph = len(groups)

    # number of fragment graphs per molecule
    df_fcg_nfcgpermol = groups.count().rename({'idfcg': 'NumFCG'}, axis=1).groupby('NumFCG').count().rename({'nfrags': 'Count'}, axis=1).reset_index()

    # df_edges is used to define df_fcg_fcc, but it was plagued with duplicate combinations (common parts in alternate fcg)
    # this snippet regenerates df_edges, not from the graph (which does not contain the occurrence id of the fragments),
    # but from the fcg_str, which does.
    # NOTE(review): the parsing below presumably expects combinations written as 'idf1:x@fcp1[fcc]idf2:y@fcp2' and joined with '-' - confirm
    idm_groups = df_fcg.groupby('idm')
    rows = []
    cols = ['idm', 'idfcg', 'idf1', 'fcp_1', 'fcc', 'idf2', 'fcp_2']
    for gid, g in idm_groups:
        # the set removes combinations that are found in several graphs of the same molecule
        combinations = set('-'.join(g['fcg_str']).split('-'))
        for comb in combinations:
            f1 = comb.split(':')[0]
            f2 = comb.split(']')[1].split(':')[0]
            fcc = comb.split('[')[1].split(']')[0]
            fcp1 = comb.split('@')[1].split('[')[0]
            fcp2 = comb.split('@')[-1]
            rows.append([gid, -1, f1, fcp1, fcc, f2, fcp2])  # idfcg is set to -1 as combinations are pooled by molecule
    df_edges = pd.DataFrame(rows, columns=cols)

    # fs analysis

    # initialization of a common df useful for fs analysis
    df_fcg['aidxs'] = df_fcg['_d_aidxs'].map(lambda x: [v for vals in x.values() for v in vals])  # extract all values from dict: list of tuples
    df_fcg['aidxs'] = df_fcg['aidxs'].map(lambda x: [v for vals in x for v in vals])   # flatten the list, duplicates are removed later when counting hac_frags

    groups = df_fcg.groupby('idm')
    df_fcg_grouped = groups.agg({'frags': 'sum'}).reset_index(drop=True)  # concatenate lists in the same group
    df_fcg_grouped['frags'] = df_fcg_grouped['frags'].map(lambda x: list(set(x)))  # count each occurrence of a fragment
    df_fcg_grouped['n_frags'] = df_fcg_grouped['frags'].map(lambda x: len(x))

    # fragment hits per mol
    df_fcg_nhits = df_fcg_grouped[['n_frags', 'frags']].groupby('n_frags').count().reset_index().rename({'frags': 'Count', 'n_frags': 'NumFrags'}, axis=1)
    df_fcg_nhits['Perc_Mols'] = df_fcg_nhits['Count'].map(lambda x: f"{x / num_tot_mol_graph:.2%}")

    # top fragments
    df_fcg_topfrags = df_fcg_grouped['frags'].apply(pd.Series).reset_index().melt(id_vars='index').dropna()[['index', 'value']]  # ungroup values by frag id in list

    df_fcg_topfrags['value'] = df_fcg_topfrags['value'].map(lambda x: x.split(':')[0])  # keep only what is before the first ':'
    df_fcg_topfrags = df_fcg_topfrags.groupby('value').count().reset_index().rename({'value': 'idf', 'index': 'Count'}, axis=1).sort_values('Count', ascending=False).reset_index(drop=True)  # count and sort idfs
    n_fcg_nhits_tot = df_fcg_topfrags['Count'].sum()

    # fragment ratio per molecule
    df_fcg_fragratio = groups.agg({'aidxs': 'sum', 'hac_mol': 'first', 'mol': 'first'}).reset_index()  # concatenate all aidxs obtained previously
    df_fcg_fragratio['hac_frags'] = df_fcg_fragratio['aidxs'].map(lambda x: len(set(x)))  # the number of unique atom indices is the number of hac in fragments
    df_fcg_fragratio['fragratio'] = df_fcg_fragratio['hac_frags'] / df_fcg_fragratio['hac_mol']

    # add fragment ratio without side chains
    df_fcg_fragratio['tmp'] = df_fcg_fragratio.apply(lambda x: compute_mol_coverage_without_side_chains(x['mol'], x['aidxs']), axis=1)
    df_fcg_fragratio['hac_mol_wo_side_chain'] = df_fcg_fragratio['tmp'].map(lambda x: x[0])
    df_fcg_fragratio['hac_frag_wo_side_chain'] = df_fcg_fragratio['tmp'].map(lambda x: x[1])
    df_fcg_fragratio['fragratio_wo_side_chain'] = df_fcg_fragratio['tmp'].map(lambda x: x[2])
    df_fcg_fragratio.drop('tmp', axis=1, inplace=True)

    df_fcg_fragratio.drop(['aidxs', 'mol'], axis=1, inplace=True)

    # unique fragment hits per mol
    df_fcg_grouped['frags_u'] = df_fcg_grouped['frags'].map(lambda x: list(set([v.split(':')[0] for v in x])))
    df_fcg_grouped['n_frags_u'] = df_fcg_grouped['frags_u'].map(lambda x: len(x))
    df_fcg_nhits_u = df_fcg_grouped[['n_frags_u', 'frags_u']].groupby('n_frags_u').count().reset_index().rename({'frags_u': 'Count', 'n_frags_u': 'NumFrags'}, axis=1)
    # NOTE(review): Count is here a number of molecules, so this sum is the number of molecules, whereas
    # n_fcg_nhits_tot is summed from df_fcg_topfrags - confirm df_fcg_topfrags_u was not intended instead
    n_fcg_nhits_u_tot = df_fcg_nhits_u['Count'].sum()

    # top unique fragments
    df_fcg_topfrags_u = df_fcg_grouped['frags_u'].apply(pd.Series).reset_index().melt(id_vars='index').dropna()[['index', 'value']]  # ungroup values by frag id in list
    df_fcg_topfrags_u = df_fcg_topfrags_u.groupby('value').count().reset_index().rename({'value': 'idf', 'index': 'Count'}, axis=1).sort_values('Count', ascending=False).reset_index(drop=True)  # count and sort idfs

    # top fragment combinations
    # compute smiles
    ds_frags = list(df_fcg['_d_mol_frags'].map(lambda x: {str(k): Chem.MolToSmiles(v) for k, v in x.items()}).values)
    # gather smiles in a dict
    d_frags = {}
    for d_frag in ds_frags:
        d_frags.update(d_frag)
    # add smiles columns for top combination
    df_edges['smiles_frag_1'] = df_edges['idf1'].map(lambda x: d_frags[x])
    df_edges['smiles_frag_2'] = df_edges['idf2'].map(lambda x: d_frags[x])

    # fragment combination categories
    d_tmp = get_dfs_fcc_from_df_fc(df_edges)

    # record results
    d['num_tot_fcg_graph'] = num_tot_fcg_graph
    d['num_tot_mol_graph'] = num_tot_mol_graph
    d['n_fcg_nhits_tot'] = n_fcg_nhits_tot
    d['n_fcg_nhits_u_tot'] = n_fcg_nhits_u_tot
    d['df_fcg_nfcgpermol'] = df_fcg_nfcgpermol
    d['df_fcg_nhits'] = df_fcg_nhits
    d['df_fcg_topfrags'] = df_fcg_topfrags
    d['df_fcg_fragratio'] = df_fcg_fragratio
    d['df_fcg_nhits_u'] = df_fcg_nhits_u
    d['df_fcg_topfrags_u'] = df_fcg_topfrags_u
    d['df_fcg_fcc'] = d_tmp['df_fcc']
    d['df_fcg_fc'] = d_tmp['df_fc']

    return d


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BEGIN ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


def main():
    """Extract data from a FCG chunk and save it into one CSV file per output.

    The output files can then be concatenated with other chunks for reporting.
    Output files are named after the input file (without any extension) and are
    written in the output directory. The counts file is written last.

    This script uses the installed npfc library in your favorite env manager.

    Example:

        $ input_file=fc/04_synthetic/chembl/data/prep/natref_dnp/frags_crms/09_fcg/chembl_001_fcg.csv.gz
        $ output_dir=fc/04_synthetic/chembl/data/prep/natref_dnp/frags_crms/09_fcg/report/data
        $ report_fcg_chunk $input_file $output_dir

    """
    # parameters CLI
    parser = argparse.ArgumentParser(description="Script for extracting data from FCG chunks. The output files can then be concatenated with other chunks for reporting.", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('input_file', type=str, default=None, help="Input file (i.e. fc/04_synthetic/chembl/data/prep/natref_dnp/frags_crms/09_fcg/data/chembl_001_fcg.csv.gz)")
    parser.add_argument('output_dir', type=str, default=None, help="Output directory")
    # NOTE(review): suffix is parsed and logged but it is not used for naming the output files - confirm what is expected
    parser.add_argument('-s', '--suffix', type=str, default=None, help="Suffix to append for output files. By default, None so the complete suffix is '_fcg'. Useful to differenciate between PNPs ('fcg_pnp') and NPLs ('fcg_npl').")
    parser.add_argument('--log', type=str, default='INFO', help="Specify level of logging. Possible values are: CRITICAL, ERROR, WARNING, INFO, DEBUG.")
    args = parser.parse_args()

    # check arguments

    # I/O
    utils.check_arg_input_file(args.input_file)
    utils.check_arg_output_dir(args.output_dir)
    output_dir = Path(args.output_dir)

    # prefix: the name of the input file without any extension (chembl_001_fcg.csv.gz -> chembl_001_fcg)
    prefix = Path(args.input_file).stem
    while '.' in prefix:
        stem = Path(prefix).stem
        if stem == prefix:
            # no extension left to strip (i.e. '.hidden' or 'name.'), stop here instead of looping forever
            break
        prefix = stem

    # output files, in the order they are displayed in the log
    output_names = ('counts', 'nfcgpermol', 'topfrags', 'topfrags_u', 'nhits', 'nhits_u', 'fcc', 'fc', 'fragratio')
    output_files = {name: f"{output_dir}/{prefix}_{name}.csv" for name in output_names}

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ INIT ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #

    # logging
    logger = utils._configure_logger(args.log)
    pd.options.mode.chained_assignment = None  # disable pd.io.pytables.SettingWithCopyWarning
    warnings.filterwarnings('ignore', category=pd.io.pytables.PerformanceWarning)  # if None is returned instead of a molecule, do not complain about mixed types
    lg = RDLogger.logger()
    lg.setLevel(RDLogger.INFO)
    pad = 40

    # display infos
    logger.info("LIBRARY VERSIONS")
    logger.info("rdkit".ljust(pad) + f"{rdkit.__version__}")
    logger.info("pandas".ljust(pad) + f"{pd.__version__}")
    logger.info("npfc".ljust(pad) + f"{npfc.__version__}")
    logger.info("ARGUMENTS")
    logger.info("input_file".ljust(pad) + f"{args.input_file}")
    logger.info("WD_out".ljust(pad) + f"{args.output_dir}")
    logger.info("suffix".ljust(pad) + f"{args.suffix}")
    logger.info("OUTPUT FILES")
    for name in output_names:
        logger.info(f"output_{name}".ljust(pad) + f"{output_files[name]}")
    logger.info("log".ljust(pad) + f"{args.log}")
    # timing starts here, so argument parsing and logging are not part of the reported times
    d0 = datetime.now()

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BEGIN ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #

    # parse FCG chunk
    logger.info("PARSING OF THE FCG CHUNK...")
    results = parse_chunk_fcg(args.input_file)
    d1 = datetime.now()

    # all tables: nfcgpermol, topfrags, topfrags_u, nhits, nhits_u, fcc, fc and fragratio
    for name in output_names:
        if name == 'counts':
            continue
        save.file(results[f'df_fcg_{name}'], output_files[name])

    # counts, should be written last for coherency with smk pipeline
    count_names = ['num_tot_fcg_graph',
                   'num_tot_mol_graph',
                   'n_fcg_nhits_tot',
                   'n_fcg_nhits_u_tot',
                   ]
    df_counts = pd.DataFrame([[results[name] for name in count_names]], columns=count_names)
    save.file(df_counts, output_files['counts'])

    d2 = datetime.now()

    # end
    logger.info("SUMMARY")
    logger.info("COMPUTATIONAL TIME: PARSING CHUNK".ljust(pad * 2) + f"{d1-d0}")
    logger.info("COMPUTATIONAL TIME: SAVING RESULTS".ljust(pad * 2) + f"{d2-d1}")
    logger.info("COMPUTATIONAL TIME: TOTAL".ljust(pad * 2) + f"{d2-d0}")
    logger.info("END")


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MAIN ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


if __name__ == '__main__':
    # run the command line interface, then explicitly exit with a success status
    main()
    raise SystemExit(0)
