#!/usr/bin/env python

"""
Script report_pnp
==========================
This script computes the report (CSV files and plots) on the Fragment Combination Graphs (FCG) found in data chunks, optionally restricted to the PNP or NPL subset.
"""

# standard
import warnings
import logging
import argparse
import sys
from shutil import rmtree
from datetime import datetime
import re
from pathlib import Path
from collections import OrderedDict
# data handling
import numpy as np
import json
import pandas as pd
from pandas import DataFrame
import networkx as nx
# data visualization
from matplotlib import pyplot as plt
import seaborn as sns
import matplotlib.ticker as mticker
# from pylab import savefig
from adjustText import adjust_text
# chemoinformatics
import rdkit
from rdkit import Chem
from rdkit.Chem import Mol
# docs
from typing import List
from typing import Tuple
from rdkit import RDLogger
# custom libraries
import npfc
from npfc import utils
from npfc import load
from npfc import save
from npfc import report
from npfc import draw
from npfc import fragment_combination
from multiprocessing import Pool

# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FUNCTIONS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


def _count_fc_per_mol(df_fc: DataFrame, df_fc_sub: DataFrame, count_col: str, num_mols_tot: int) -> Tuple[int, int, DataFrame]:
    """Count the molecules with/without the Fragment Combinations of a subset and their distribution.

    :param df_fc: all Fragment Combinations (one row per combination, with columns idm, idf1, idf2)
    :param df_fc_sub: the subset of df_fc to count (e.g. only the 'ffs' rows)
    :param count_col: name of the column holding the number of combinations per molecule
    :param num_mols_tot: the number of distinct molecules (idm) in df_fc
    :return: a tuple (number of molecules with at least one combination of df_fc_sub,
             number of molecules with none,
             DataFrame with columns [count_col, 'Count'] giving the number of molecules per
             number of combinations, its first row being the molecules with 0 combination)
    """
    if len(df_fc_sub) == 0:
        return 0, num_mols_tot, DataFrame([[0, num_mols_tot]], columns=[count_col, 'Count'])

    num_mols_with = len(df_fc_sub.groupby('idm'))
    num_mols_without = len(df_fc[~df_fc['idm'].isin(df_fc_sub['idm'])].groupby('idm'))
    # number of combinations per molecule, then number of molecules per number of combinations
    df_count = (df_fc_sub[['idm', 'idf1', 'idf2']]
                .groupby('idm').count()
                .rename({'idf1': count_col}, axis=1)
                .groupby(count_col).count()
                .reset_index()
                .rename({'idf2': 'Count'}, axis=1))
    df_zero = DataFrame({count_col: [0], 'Count': [num_mols_without]})
    df_count = pd.concat([df_zero, df_count]).reset_index(drop=True)
    return num_mols_with, num_mols_without, df_count


def get_dfs_fcc_from_df_fc(df_fc: DataFrame) -> dict:
    """From a DataFrame with Fragment Combinations (edges), identify FP (ffs, ffo) and TP (the rest).

    It returns a dictionary with various counts:
        - df_fc: a DataFrame with counts of Fragment Combinations
        - df_fcc: a DataFrame with counts of Fragment Combination Categories
        - df_ffs: a DataFrame with counts of the number of ffs per molecule
        - df_ffo: a DataFrame with counts of the number of ffo per molecule
        - num_fc_ffs: the number of ffs Fragment Combinations
        - num_fc_ffo: the number of ffo Fragment Combinations
        - num_fc_tp: the number of true positives Fragment Combinations, i.e. that are not ffs or ffo
        - num_fc_tot: the total number of Fragment Combinations
        - num_mols_tot: the total number of molecules
        - num_mols_tp: the number of molecules with only true positives Fragment Combinations
        - num_mols_ffs: the number of molecules with at least 1 ffs Fragment Combination
        - num_mols_noffs: the number of molecules with 0 ffs Fragment Combinations
        - num_mols_ffo: the number of molecules with at least 1 ffo Fragment Combination
        - num_mols_noffo: the number of molecules with 0 ffo Fragment Combinations

    This function is used within iterations over chunks in other functions, so counts have to be summed up.

    :param df_fc: a dataframe with fragment combinations (columns idm, idf1, idf2, fcc, mol_frag_1, mol_frag_2)
    :return: a dictionary with counts of Fragment Combinations.
    """
    # init
    categories = fragment_combination.get_fragment_combination_categories()

    # count
    num_fc_tot = len(df_fc.index)
    num_mols_tot = len(df_fc.groupby('idm'))

    # separate df into 3 parts: ffs, ffo and tp

    # ffs
    df_fc_ffs = df_fc[df_fc['fcc'] == 'ffs']
    num_fc_ffs = len(df_fc_ffs)
    num_mols_ffs, num_mols_noffs, df_fc_ffs_count = _count_fc_per_mol(df_fc, df_fc_ffs, 'NumSubstructures', num_mols_tot)

    # ffo
    df_fc_ffo = df_fc[df_fc['fcc'] == 'ffo']
    num_fc_ffo = len(df_fc_ffo)
    num_mols_ffo, num_mols_noffo, df_fc_ffo_count = _count_fc_per_mol(df_fc, df_fc_ffo, 'NumOverlaps', num_mols_tot)

    # tp: all remaining combinations; a tp molecule has neither ffs nor ffo combinations
    df_fc_tp = df_fc[~df_fc['fcc'].isin(['ffs', 'ffo'])]
    mask_tp_mols = (~df_fc_tp['idm'].isin(df_fc_ffs['idm'])) & (~df_fc_tp['idm'].isin(df_fc_ffo['idm']))
    num_mols_tp = len(df_fc_tp[mask_tp_mols].groupby('idm'))
    num_fc_tp = len(df_fc_tp)

    # fcc: counts per category, with every known category present (0 if absent) and in the predefined order
    df_fcc_count_default = pd.DataFrame({'fcc': categories, 'Count': [0] * len(categories)})
    df_fcc_count = df_fc_tp[['fcc', 'idm']].groupby('fcc').count().rename({'idm': 'Count'}, axis=1).reset_index()
    df_fcc_count = pd.concat([df_fcc_count, df_fcc_count_default]).groupby('fcc').sum().T
    df_fcc_count = df_fcc_count[categories]
    df_fcc_count = df_fcc_count.T.reset_index().rename({'index': 'fcc'}, axis=1)

    # top fc
    df_fc_count = df_fc_tp[['idf1', 'idf2', 'fcc', 'idm', 'mol_frag_1', 'mol_frag_2']].groupby(['idf1', 'idf2', 'fcc', 'mol_frag_1', 'mol_frag_2']).count().rename({'idm': 'Count'}, axis=1).reset_index()
    df_fc_count['fc'] = df_fc_count['idf1'] + '[' + df_fc_count['fcc'] + ']' + df_fc_count['idf2']
    df_fc_count = df_fc_count.drop(['idf1', 'idf2', 'fcc'], axis=1)

    return {'df_fc': df_fc_count,
            'df_fcc': df_fcc_count,
            'df_ffs': df_fc_ffs_count,
            'df_ffo': df_fc_ffo_count,
            'num_fc_ffs': num_fc_ffs,
            'num_fc_ffo': num_fc_ffo,
            'num_fc_tp': num_fc_tp,
            'num_fc_tot': num_fc_tot,
            'num_mols_tot': num_mols_tot,
            'num_mols_tp': num_mols_tp,
            'num_mols_ffs': num_mols_ffs,
            'num_mols_noffs': num_mols_noffs,
            'num_mols_ffo': num_mols_ffo,
            'num_mols_noffo': num_mols_noffo,
            }


def print_header(pad=160):
    """Log the library versions and the command-line arguments of the script.

    Relies on the module-level ``logger``, ``args``, ``dataset``, ``prefix`` and ``color``
    that are defined when the script is run.

    :param pad: width used to left-justify the labels
    """
    logger.info("RUNNING REPORT_PROTOCOL")

    logger.info("LIBRARY VERSIONS:")
    logger.info("rdkit".ljust(pad) + f"{rdkit.__version__}")
    logger.info("pandas".ljust(pad) + f"{pd.__version__}")
    logger.info("npfc".ljust(pad) + f"{npfc.__version__}")
    logger.info("seaborn".ljust(pad) + f"{sns.__version__}")

    # only log arguments actually defined by this script's parser
    logger.info("ARGUMENTS:")
    logger.info("WD".ljust(pad) + f"{args.wd}")
    logger.info("WD_OUT".ljust(pad) + f"{args.wd_out}")
    logger.info("SUBSET".ljust(pad) + f"{args.subset}")
    logger.info("DATASET".ljust(pad) + f"{dataset}")
    if args.prefix is None:
        logger.info("PREFIX".ljust(pad) + f"{prefix} (default)")
    else:
        logger.info("PREFIX".ljust(pad) + f"{prefix}")
    # same palette as the one used to resolve `color` in the main block
    if args.color in report.DEFAULT_PALETTE:
        logger.info("COLOR".ljust(pad) + f"{color} ('{args.color}')")
    else:
        logger.info("COLOR".ljust(pad) + f"{color}")
    logger.info("PLOT FORMAT".ljust(pad) + f"{args.plotformat}")
    logger.info("COMPUTE CSV ONLY".ljust(pad) + f"{args.csv}")


def compute_mol_coverage_without_side_chains(mol, l_aidxs):
    """Compute the coverage of a molecule by its fragments, ignoring side chains.

    Atoms are classified as:
        - ring atoms: atoms of the rings, fused with ``utils.fuse_rings``
        - linker atoms: non-ring atoms on the shortest path between the first atoms of
          every pair of fused ring systems
        - side-chain atoms: all the other atoms
    Side-chain atoms are removed from both the molecule and the fragment atoms before
    computing the ratio.

    :param mol: the molecule
    :param l_aidxs: the atom indices of the molecule covered by fragments (duplicates allowed)
    :return: [number of non side-chain atoms, number of non side-chain fragment atoms, ratio];
             the ratio is NaN when the molecule has no ring nor linker atom (instead of
             raising ZeroDivisionError).
    """
    rings = utils.fuse_rings(mol.GetRingInfo().AtomRings())
    ring_atoms = {aidx for ring in rings for aidx in ring}

    # sets instead of lists: membership tests were rebuilding lists on every check (O(n^2))
    linker_atoms = set()
    for i in range(len(rings)):
        for j in range(i + 1, len(rings)):
            shortest_path = Chem.GetShortestPath(mol, rings[i][0], rings[j][0])
            # every ring (including rings i and j) is part of ring_atoms, so this
            # also excludes the two rings being linked
            linker_atoms.update(aidx for aidx in shortest_path if aidx not in ring_atoms)

    # define side chains as atoms not part of linkers and rings
    num_atoms = mol.GetNumAtoms()
    side_chain_atoms = set(range(num_atoms)) - ring_atoms - linker_atoms

    # remove side chains from the equation
    num_core_atoms = num_atoms - len(side_chain_atoms)
    frag_atoms = set(l_aidxs) - side_chain_atoms
    if num_core_atoms == 0:
        # no ring nor linker: the coverage without side chains is undefined
        return [0, len(frag_atoms), float('nan')]
    return [num_core_atoms, len(frag_atoms), len(frag_atoms) / num_core_atoms]


def _get_empty_chunk_fcg() -> dict:
    """Return the result of parse_chunk_fcg for a chunk without any fragment graph.

    The empty DataFrames have the same columns as the ones computed for non-empty chunks,
    so that they can be concatenated and grouped together in get_df_fcg.
    """
    return {
        'num_tot_fcg_graph': 0,
        'num_tot_mol_graph': 0,
        'n_fcg_nhits_tot': 0,
        'n_fcg_nhits_u_tot': 0,
        'df_fcg_nfcgpermol': pd.DataFrame([], columns=['NumFCG', 'Count']),
        'df_fcg_nhits': pd.DataFrame([], columns=['NumFrags', 'Count', 'Perc_Mols']),
        'df_fcg_top_frags': pd.DataFrame([], columns=['idf', 'Count']),
        'df_fcg_frag_ratio': pd.DataFrame([], columns=['idm', 'hac_mol', 'hac_frags', 'frag_ratio',
                                                       'hac_mol_wo_side_chain', 'hac_frag_wo_side_chain',
                                                       'frag_ratio_wo_side_chain']),
        'df_fcg_nhits_u': pd.DataFrame([], columns=['NumFrags', 'Count']),
        'df_fcg_top_frags_u': pd.DataFrame([], columns=['idf', 'Count']),
        'df_fcg_fcc': pd.DataFrame([], columns=['fcc', 'Count']),
        'df_fcg_fc': pd.DataFrame([], columns=['mol_frag_1', 'mol_frag_2', 'Count', 'fc']),
    }


def _get_df_edges_from_fcg_str(df_fcg: DataFrame) -> DataFrame:
    """Rebuild the Fragment Combinations (edges) of each molecule from the 'fcg_str' column.

    df_edges used to be derived from the graphs, but it was plagued with duplicate combinations
    (common parts in alternate fcgs); the graphs also do not contain the occurrence id of the
    fragments whereas fcg_str does. Combinations shared by alternate fcgs of the same molecule
    are therefore kept only once per molecule.

    As expected by the parsing below, combinations are separated by '-' and each one looks like
    '<idf1>:...@<fcp_1>[<fcc>]<idf2>:...@<fcp_2>'.

    :param df_fcg: the fcgs of a chunk, with columns idm and fcg_str
    :return: a DataFrame with columns idm, idfcg (always -1), idf1, fcp_1, fcc, idf2, fcp_2
    """
    rows = []
    for idm, group in df_fcg.groupby('idm'):
        # sorted for a deterministic row order (set iteration order depends on string hashing)
        for comb in sorted(set('-'.join(group['fcg_str']).split('-'))):
            idf1 = comb.split(':')[0]
            idf2 = comb.split(']')[1].split(':')[0]
            fcc = comb.split('[')[1].split(']')[0]
            fcp1 = comb.split('@')[1].split('[')[0]
            fcp2 = comb.split('@')[-1]
            rows.append([idm, -1, idf1, fcp1, fcc, idf2, fcp2])
    return pd.DataFrame(rows, columns=['idm', 'idfcg', 'idf1', 'fcp_1', 'fcc', 'idf2', 'fcp_2'])


def parse_chunk_fcg(c):
    """Compute the Fragment Combination Graph statistics of a single chunk.

    :param c: the chunk to parse, loaded with ``load.file``
    :return: a dictionary with counts (num_tot_fcg_graph, num_tot_mol_graph, n_fcg_nhits_tot,
             n_fcg_nhits_u_tot) and DataFrames (df_fcg_*) meant to be summed up over chunks
             in get_df_fcg.
    """
    # retrieve data; df_fcg already sorted for the best examples per case
    df_fcg = load.file(c, decode=['_fcg', '_d_mol_frags', '_frags']).sort_values(["idm", "nfrags"], ascending=True)
    num_tot_fcg_graph = len(df_fcg.index)
    if num_tot_fcg_graph == 0:
        return _get_empty_chunk_fcg()

    df_fcg = df_fcg.rename({'_frags': 'frags', '_frags_u': 'frags_u'}, axis=1)
    df_fcg['frags'] = df_fcg['frags'].map(list)
    groups_fcg = df_fcg[['idm', 'idfcg', 'nfrags']].groupby('idm')
    num_tot_mol_graph = len(groups_fcg)

    # number of fragment graphs per molecule
    df_fcg_nfcgpermol = groups_fcg.count().rename({'idfcg': 'NumFCG'}, axis=1).groupby('NumFCG').count().rename({'nfrags': 'Count'}, axis=1).reset_index()

    # fragment combinations, without duplicates from alternate fcgs
    df_edges = _get_df_edges_from_fcg_str(df_fcg)

    # fs analysis: flatten the dict of lists of tuples of atom indices into a single list
    df_fcg['aidxs'] = df_fcg['_d_aidxs'].map(lambda d: [aidx for l_tuples in d.values() for t in l_tuples for aidx in t])

    groups_mol = df_fcg.groupby('idm')
    df_fcg_grouped = groups_mol.agg({'frags': 'sum'}).reset_index(drop=True)  # concatenate lists in the same group
    df_fcg_grouped['frags'] = df_fcg_grouped['frags'].map(lambda x: list(set(x)))  # count each fragment occurrence once per molecule
    df_fcg_grouped['n_frags'] = df_fcg_grouped['frags'].map(len)

    # fragment hits per mol
    df_fcg_nhits = df_fcg_grouped[['n_frags', 'frags']].groupby('n_frags').count().reset_index().rename({'frags': 'Count', 'n_frags': 'NumFrags'}, axis=1)
    df_fcg_nhits['Perc_Mols'] = df_fcg_nhits['Count'].map(lambda x: f"{x / num_tot_mol_graph:.2%}")

    # top fragments: one hit per fragment occurrence per molecule, counted by idf
    df_fcg_top_frags = (df_fcg_grouped['frags'].explode().dropna()
                        .map(lambda x: x.split(':')[0])
                        .value_counts()
                        .rename_axis('idf')
                        .reset_index(name='Count')
                        .sort_values('Count', ascending=False)
                        .reset_index(drop=True))
    n_fcg_nhits_tot = df_fcg_top_frags['Count'].sum()

    # fragment ratio per molecule
    df_fcg_frag_ratio = groups_mol.agg({'aidxs': 'sum', 'hac_mol': 'first', 'mol': 'first'}).reset_index()  # concatenate all aidxs obtained previously
    df_fcg_frag_ratio['hac_frags'] = df_fcg_frag_ratio['aidxs'].map(lambda x: len(set(x)))  # the number of unique atom indices is the hac in fragments
    df_fcg_frag_ratio['frag_ratio'] = df_fcg_frag_ratio['hac_frags'] / df_fcg_frag_ratio['hac_mol']

    # add fragment ratio without side chains
    coverage = df_fcg_frag_ratio.apply(lambda row: compute_mol_coverage_without_side_chains(row['mol'], row['aidxs']), axis=1)
    df_fcg_frag_ratio['hac_mol_wo_side_chain'] = coverage.map(lambda x: x[0])
    df_fcg_frag_ratio['hac_frag_wo_side_chain'] = coverage.map(lambda x: x[1])
    df_fcg_frag_ratio['frag_ratio_wo_side_chain'] = coverage.map(lambda x: x[2])
    df_fcg_frag_ratio = df_fcg_frag_ratio.drop(['aidxs', 'mol'], axis=1)

    # unique fragment hits per mol (occurrence ids dropped)
    df_fcg_grouped['frags_u'] = df_fcg_grouped['frags'].map(lambda x: list(set([v.split(':')[0] for v in x])))
    df_fcg_grouped['n_frags_u'] = df_fcg_grouped['frags_u'].map(len)
    df_fcg_nhits_u = df_fcg_grouped[['n_frags_u', 'frags_u']].groupby('n_frags_u').count().reset_index().rename({'frags_u': 'Count', 'n_frags_u': 'NumFrags'}, axis=1)

    # top unique fragments: number of molecules containing each idf
    df_fcg_top_frags_u = (df_fcg_grouped['frags_u'].explode().dropna()
                          .value_counts()
                          .rename_axis('idf')
                          .reset_index(name='Count')
                          .sort_values('Count', ascending=False)
                          .reset_index(drop=True))
    # total number of unique fragment hits (summing df_fcg_nhits_u['Count'] gave the number of molecules)
    n_fcg_nhits_u_tot = df_fcg_top_frags_u['Count'].sum()

    # fragment combination categories and top fragment combinations
    d_frags = {}
    for d_mol_frags in df_fcg['_d_mol_frags']:
        d_frags.update({str(k): Chem.MolToSmiles(v) for k, v in d_mol_frags.items()})
    df_edges['mol_frag_1'] = df_edges['idf1'].map(lambda x: d_frags[x])
    df_edges['mol_frag_2'] = df_edges['idf2'].map(lambda x: d_frags[x])
    d_tmp = get_dfs_fcc_from_df_fc(df_edges)

    return {
        'num_tot_fcg_graph': num_tot_fcg_graph,
        'num_tot_mol_graph': num_tot_mol_graph,
        'n_fcg_nhits_tot': n_fcg_nhits_tot,
        'n_fcg_nhits_u_tot': n_fcg_nhits_u_tot,
        'df_fcg_nfcgpermol': df_fcg_nfcgpermol,
        'df_fcg_nhits': df_fcg_nhits,
        'df_fcg_top_frags': df_fcg_top_frags,
        'df_fcg_frag_ratio': df_fcg_frag_ratio,
        'df_fcg_nhits_u': df_fcg_nhits_u,
        'df_fcg_top_frags_u': df_fcg_top_frags_u,
        'df_fcg_fcc': d_tmp['df_fcc'],
        'df_fcg_fc': d_tmp['df_fc'],
    }


def _format_perc(count, total) -> str:
    """Format ``count / total`` as a percentage with two decimals ('0.00%' when total is 0)."""
    if not total:
        return f"{0:.2%}"
    return f"{count / total:.2%}"


def get_df_fcg(WD: Path, subset: str = 'pnp') -> Tuple[DataFrame, ...]:
    """Get a tuple of DFs summarizing the Fragment Graph Generation step.

    Chunks found in WD are parsed in parallel with parse_chunk_fcg and their results summed up.

    :param WD: the directory containing the chunks
    :param subset: 'pnp' or 'npl' to parse only the chunks of that subset, None to parse all chunks
    :return: a tuple (df_fcg_nhits, df_fcg_nhits_u, df_fcg_frag_ratio, df_fcg_top_frags,
             df_fcg_top_frags_u, df_fcg_fcc, df_fcg_fc, df_fcg_nfcgpermol)
    :raises ValueError: if no chunk matches the pattern in WD
    """
    WD = Path(WD)
    # define data
    if subset is None:  # i.e. natural (fcg)
        pattern = '.*_([0-9]{3})?_.*.csv.gz'
    else:  # pnp or npl
        pattern = '.*_([0-9]{3})?_' + subset + '.csv.gz'
    chunks = report._get_chunks(WD, pattern)
    if len(chunks) == 0:
        raise ValueError(f"ERROR! NO CHUNK FOUND IN {WD} MATCHING PATTERN '{pattern}'!")
    categories = fragment_combination.get_fragment_combination_categories()

    # initialize chunk iteration
    logger.info(f"FCG -- STARTING CHUNK ITERATION IN {WD} ({len(chunks)})...")

    # chunk iteration: function has loggings at info level that would flood the log file because of iteration
    previous_level = logger.level
    logger.setLevel(logging.WARNING)
    try:
        with Pool() as pool:
            results = pool.map(parse_chunk_fcg, chunks)
    finally:
        logger.setLevel(previous_level)  # restore the user-defined level, even on failure

    def concat_chunks(key: str) -> DataFrame:
        """Concatenate the DataFrames stored under *key* in every chunk result."""
        return pd.concat([x[key] for x in results])

    # counts
    num_tot_fcg_graph = sum(x['num_tot_fcg_graph'] for x in results)
    num_tot_mol_graph = sum(x['num_tot_mol_graph'] for x in results)
    n_fcg_nhits_tot = sum(x['n_fcg_nhits_tot'] for x in results)
    n_fcg_nhits_u_tot = sum(x['n_fcg_nhits_u_tot'] for x in results)

    logger.info("FCG -- COMPLETED CHUNK ITERATION...")

    # fcg_nfcgpermol
    logger.info("FCG -- RESULTS FOR THE NUMBER OF FRAGMENT GRAPHS PER MOLECULE")
    logger.info(f"FCG -- TOTAL NUMBER OF FRAGMENT GRAPHS: {num_tot_fcg_graph:,d}")
    logger.info(f"FCG -- TOTAL NUMBER OF MOLECULES: {num_tot_mol_graph:,d}")
    df_fcg_nfcgpermol = concat_chunks('df_fcg_nfcgpermol').groupby('NumFCG')['Count'].sum().reset_index()
    df_fcg_nfcgpermol['Perc_Mols'] = df_fcg_nfcgpermol['Count'].map(lambda x: _format_perc(x, num_tot_mol_graph))
    logger.info(f"FCG -- RESULTS FOR THE NUMBER OF FCG PER MOLECULE:\n\n{df_fcg_nfcgpermol}\n")

    # fcg_nhits (only 'Count' is summed: per-chunk percentages are recomputed)
    logger.info("FCG -- INVESTIGATING FOR THE NUMBER OF FRAGMENT HITS PER MOLECULE")
    df_fcg_nhits = concat_chunks('df_fcg_nhits').groupby('NumFrags')['Count'].sum().reset_index().sort_values('NumFrags').reset_index(drop=True)
    df_fcg_nhits['Perc_Mols'] = df_fcg_nhits['Count'].map(lambda x: _format_perc(x, num_tot_mol_graph))
    logger.info(f"FCG -- RESULTS FOR THE NUMBER OF FRAGMENT HITS PER MOLECULE\n\n{df_fcg_nhits}\n")

    # fcg_top_frags
    logger.info("FCG -- INVESTIGATING FOR THE TOP FRAGMENTS")
    df_fcg_top_frags = concat_chunks('df_fcg_top_frags').groupby('idf')['Count'].sum().reset_index().sort_values('Count', ascending=False).reset_index(drop=True)
    df_fcg_top_frags['Rank'] = df_fcg_top_frags.index + 1
    logger.info(f"FCG -- TOTAL NUMBER OF FRAGMENT HITS={n_fcg_nhits_tot:,}")
    df_fcg_top_frags['Perc_FHits'] = df_fcg_top_frags['Count'].map(lambda x: _format_perc(x, n_fcg_nhits_tot))
    df_fcg_top_frags['idf'] = df_fcg_top_frags['idf'].astype(str)
    logger.info(f"FCG -- RESULTS FOR THE TOP FRAGMENTS\n\n{df_fcg_top_frags}\n")

    # fcg_frag_ratio
    logger.info("FCG -- INVESTIGATING FOR THE RATIO OF FRAGMENT PER MOLECULE")
    df_fcg_frag_ratio = concat_chunks('df_fcg_frag_ratio').reset_index(drop=True)
    logger.info(f"FCG -- RESULTS FOR THE RATIO OF FRAGMENT PER MOLECULE\n\n{df_fcg_frag_ratio}\n")  # only a subset will be printed by pandas if many entries

    # fcg_nhits_u
    logger.info("FCG -- INVESTIGATING THE NUMBER OF UNIQUE FRAGMENT HITS PER MOLECULE")
    df_fcg_nhits_u = concat_chunks('df_fcg_nhits_u').groupby('NumFrags')['Count'].sum().reset_index().sort_values('NumFrags').reset_index(drop=True)
    df_fcg_nhits_u['Perc_Mols'] = df_fcg_nhits_u['Count'].map(lambda x: _format_perc(x, num_tot_mol_graph))
    logger.info(f"FCG -- RESULTS FOR THE NUMBER OF UNIQUE FRAGMENT HITS PER MOLECULE\n\n{df_fcg_nhits_u}\n")

    # fcg_top_frags_u
    logger.info("FCG -- INVESTIGATING THE TOP UNIQUE FRAGMENTS")
    df_fcg_top_frags_u = concat_chunks('df_fcg_top_frags_u').groupby('idf')['Count'].sum().reset_index().sort_values('Count', ascending=False).reset_index(drop=True)
    df_fcg_top_frags_u['Rank'] = df_fcg_top_frags_u.index + 1
    logger.info(f"FCG -- TOTAL NUMBER OF UNIQUE FRAGMENT HITS={n_fcg_nhits_u_tot:,}")
    df_fcg_top_frags_u['Perc_FHits'] = df_fcg_top_frags_u['Count'].map(lambda x: _format_perc(x, n_fcg_nhits_u_tot))
    df_fcg_top_frags_u['idf'] = df_fcg_top_frags_u['idf'].astype(str)
    logger.info(f"FCG -- RESULTS FOR THE TOP UNIQUE FRAGMENTS\n\n{df_fcg_top_frags_u}\n")

    # fcg_fcc: every category in the predefined order, 0 when absent from all chunks
    logger.info("FCG -- INVESTIGATING THE FCC COUNTS")
    df_fcg_fcc = (concat_chunks('df_fcg_fcc')
                  .groupby('fcc')['Count'].sum()
                  .reindex(categories, fill_value=0)
                  .rename_axis('fcc')
                  .reset_index())
    n_fcg_fcc = len(df_fcg_fcc[df_fcg_fcc['Count'] > 0])
    logger.info(f"FCG -- TOTAL NUMBER OF FRAGMENT COMBINATIONS CATEGORIES IDENTIFIED={n_fcg_fcc:,}")
    n_fcg_fc = df_fcg_fcc['Count'].sum()
    logger.info(f"FCG -- TOTAL NUMBER OF FRAGMENT COMBINATIONS={n_fcg_fc:,}")
    df_fcg_fcc['Perc'] = df_fcg_fcc['Count'].map(lambda x: _format_perc(x, n_fcg_fc))
    logger.info(f"FCG -- RESULTS FOR THE FCC COUNTS\n\n{df_fcg_fcc}\n")

    # fcg_fc
    logger.info("FCG -- INVESTIGATING THE TOP FRAGMENT COMBINATIONS")
    df_fcg_fc = concat_chunks('df_fcg_fc')
    df_fcg_fc = df_fcg_fc.groupby(['fc', 'mol_frag_1', 'mol_frag_2'])['Count'].sum().reset_index().sort_values('Count', ascending=False).reset_index(drop=True)
    df_fcg_fc['Perc'] = df_fcg_fc['Count'].map(lambda x: _format_perc(x, n_fcg_fc))
    df_fcg_fc['Rank'] = df_fcg_fc.index + 1
    logger.info(f"FCG -- RESULT FOR THE TOP FRAGMENT COMBINATIONS\n\n{df_fcg_fc}\n")

    return (df_fcg_nhits, df_fcg_nhits_u, df_fcg_frag_ratio, df_fcg_top_frags, df_fcg_top_frags_u, df_fcg_fcc, df_fcg_fc, df_fcg_nfcgpermol)



# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BEGIN ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


def parse_bool_arg(value) -> bool:
    """Convert a command-line value such as 'True', 'false', '1' or 'no' into a bool.

    With ``type=str`` any non-empty string is truthy, so ``--clear False`` used to clear files.

    :param value: the raw value given on the command line (or an already boolean default)
    :return: the corresponding boolean
    :raises argparse.ArgumentTypeError: if the value cannot be interpreted as a boolean
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    if normalized in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value (true/false), got '{value}'")


if __name__ == '__main__':

    # NOTE(review): the statements following this block are at module level, outside this guard,
    # so importing this module (presumably including multiprocessing workers started with 'spawn')
    # would fail on the names defined here (args, logger, ...).

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ARGS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #

    d0 = datetime.now()
    parser = argparse.ArgumentParser(description="Compute all required files for analyzing FCC results", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('wd', type=str, default=None, help="Working directory where the data to parse is")
    parser.add_argument('wd_out', type=str, default=None, help="Output directory")
    parser.add_argument('-s', '--subset', type=str, default=None, help="For synthetic compounds only. In the same input wd, pnp and npl output files might be mixed. Specify here if you wish to parse 'pnp' or 'npl' compounds.")
    parser.add_argument('-d', '--dataset', type=str, default=None, help="Dataset name for using in the csv/png outputs in the report folder.")
    parser.add_argument('--prefix', type=str, default=None, help="Prefix used for output files in the data/log folders.")
    parser.add_argument('-c', '--color', type=str, default='black', help="Color to use for plots.")
    parser.add_argument('--plotformat', type=str, default='svg', help="Format to use for plots. Possible values are 'svg' and 'png'.")
    # boolean flags: values such as True/False/1/0/yes/no are converted to real booleans
    parser.add_argument('--csv', type=parse_bool_arg, default=False, help="Generate only CSV output files")
    parser.add_argument('--clear', type=parse_bool_arg, default=False, help="Force the generation of log, plot and CSV files by clearing all report files at any found specified levels.")
    parser.add_argument('--regenplots', type=parse_bool_arg, default=False, help="Force the generation of plots by clearing any pre-existing plot at any specified levels.")
    parser.add_argument('--log', type=str, default='INFO', help="Specify level of logging. Possible values are: CRITICAL, ERROR, WARNING, INFO, DEBUG.")
    args = parser.parse_args()

    # check arguments

    utils.check_arg_input_dir(args.wd)

    # prefix
    if args.prefix is None:
        logging.warning("PREFIX IS NOT SET, RESORTING TO WD DIRNAME.")
        prefix = Path(args.wd).name
    else:
        prefix = args.prefix
    if args.dataset is None:
        logging.warning("DATASET IS NOT SET, USING PREFIX INSTEAD.")
        dataset = prefix
    else:
        dataset = args.dataset

    # plotformat
    if args.plotformat not in ('svg', 'png'):
        raise ValueError(f"ERROR! UNKNOWN PLOT FORMAT! ('{args.plotformat}')")

    # subset
    if args.subset is not None:
        subset = args.subset.lower()
        if subset not in ('pnp', 'npl'):
            raise ValueError(f"ERROR! UNKNOWN SUBSET! ('{args.subset}')")
    else:
        subset = None

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ INIT ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #

    # logging
    pd.options.mode.chained_assignment = None  # disable pandas SettingWithCopyWarning
    warnings.filterwarnings('ignore', category=pd.errors.PerformanceWarning)  # if None is returned instead of a molecule, do not complain about mixed types
    lg = RDLogger.logger()
    lg.setLevel(RDLogger.INFO)
    log_file = f"{args.wd_out}/report_fcg_{prefix}.log"
    utils.check_arg_output_file(log_file)
    logger = utils._configure_logger(log_level=args.log, log_file=log_file, logger_name=log_file)

    # display rendering
    pd.set_option('display.max_columns', 20)
    pd.set_option('display.max_rows', 20)
    pd.set_option('max_colwidth', 70)
    pad_title = 80
    pad = 60

    color = report.DEFAULT_PALETTE.get(args.color, args.color)


# Log the main arguments of the run, one padded label per line.
# NOTE(review): from here on, statements run at module level (outside the `__main__` guard)
# and rely on names (logger, args, pad, log_file) that are only defined inside that guard.
logger.info("ARGUMENTS")
for arg_label, arg_value in (("WD_in", args.wd),
                             ("WD_out", args.wd_out),
                             ("log_file", log_file),
                             ("clear", args.clear),
                             ("regenplots", args.regenplots)):
    logger.info(arg_label.ljust(pad) + f"{arg_value}")


report.print_title("FRAGMENT GRAPH GENERATION", 3, pad_title)
# define outputs: all files share the "<wd_out>/<prefix>_fcg" stem and plots reuse the CSV
# basenames with the plot extension. Paths are built from the stem rather than with
# str.replace('.csv', ...), which would also rewrite any '.csv' found in wd_out or prefix.
output_stem_fcg = f"{args.wd_out}/{prefix}_fcg"
output_csv_fcg_nfragpermol = f"{output_stem_fcg}_nfragpermol.csv"
output_csv_fcg_nfragpermol_u = f"{output_stem_fcg}_nfragpermol_u.csv"
output_csv_fcg_fragmolcov = f"{output_stem_fcg}_fragmolcov.csv"
output_csv_fcg_top10frags = f"{output_stem_fcg}_top10frags.csv"
output_csv_fcg_top10frags_u = f"{output_stem_fcg}_top10frags_u.csv"
output_csv_fcg_fcc = f"{output_stem_fcg}_fcc.csv"
output_csv_fcg_top10fc = f"{output_stem_fcg}_fc.csv"
output_csv_fcg_nfcgpermol = f"{output_stem_fcg}_nfcgpermol.csv"
output_plot_fcg_nfragpermol = f"{output_stem_fcg}_nfragpermol.{args.plotformat}"
output_plot_fcg_nfragpermol_zoom = f"{output_stem_fcg}_nfragpermol_zoom.{args.plotformat}"
output_plot_fcg_nfragpermol_u = f"{output_stem_fcg}_nfragpermol_u.{args.plotformat}"
output_plot_fcg_nfragpermol_u_zoom = f"{output_stem_fcg}_nfragpermol_u_zoom.{args.plotformat}"
output_plot_fcg_fragmolcov = f"{output_stem_fcg}_fragmolcov.{args.plotformat}"
output_plot_fcg_fragmolcov_wo_side_chain = f"{output_stem_fcg}_fragmolcov_wo_side_chain.{args.plotformat}"
output_plot_fcg_top10frags = f"{output_stem_fcg}_top10frags.{args.plotformat}"
output_plot_fcg_top10frags_u = f"{output_stem_fcg}_top10frags_u.{args.plotformat}"
output_plot_fcg_fcc = f"{output_stem_fcg}_fcc.{args.plotformat}"
output_plot_fcg_top10fc = f"{output_stem_fcg}_fc.{args.plotformat}"
output_plot_fcg_nfcgpermol = f"{output_stem_fcg}_nfcgpermol.{args.plotformat}"
output_plot_fcg_nfcgpermol_zoom = f"{output_stem_fcg}_nfcgpermol_zoom.{args.plotformat}"
logger.info("FCG -- OUTPUT_CSV_FCG_NFRAGPERMOL".ljust(pad) + f"{output_csv_fcg_nfragpermol}")
logger.info("FCG -- OUTPUT_CSV_FCG_NFRAGPERMOL_U".ljust(pad) + f"{output_csv_fcg_nfragpermol_u}")
logger.info("FCG -- OUTPUT_CSV_FCG_FRAGMOLCOV".ljust(pad) + f"{output_csv_fcg_fragmolcov}")
logger.info("FCG -- OUTPUT_CSV_FCG_TOP10FRAGS".ljust(pad) + f"{output_csv_fcg_top10frags}")
logger.info("FCG -- OUTPUT_CSV_FCG_TOP10FRAGS_U".ljust(pad) + f"{output_csv_fcg_top10frags_u}")
logger.info("FCG -- OUTPUT_CSV_FCG_FCC".ljust(pad) + f"{output_csv_fcg_fcc}")
logger.info("FCG -- OUTPUT_CSV_FCG_TOP10FC".ljust(pad) + f"{output_csv_fcg_top10fc}")
logger.info("FCG -- OUTPUT_CSV_FCG_NFRAGGRAPHPERMOL".ljust(pad) + f"{output_csv_fcg_nfcgpermol}")
logger.info("FCG -- OUTPUT PLOT FILES HAVE THE SAME FILE NAMES AS OUTPUT CSV FILES")
# retrieve data; the CSV order matches the order of the DataFrames returned by get_df_fcg
output_csv_files = [output_csv_fcg_nfragpermol, output_csv_fcg_nfragpermol_u,
                    output_csv_fcg_fragmolcov,
                    output_csv_fcg_top10frags, output_csv_fcg_top10frags_u,
                    output_csv_fcg_fcc, output_csv_fcg_top10fc,
                    output_csv_fcg_nfcgpermol,
                    ]
output_plot_files = [output_plot_fcg_nfragpermol, output_plot_fcg_nfragpermol_u,
                     output_plot_fcg_fragmolcov, output_plot_fcg_fragmolcov_wo_side_chain,
                     output_plot_fcg_nfragpermol_zoom, output_plot_fcg_nfragpermol_u_zoom,
                     output_plot_fcg_top10frags, output_plot_fcg_top10frags_u,
                     output_plot_fcg_fcc, output_plot_fcg_top10fc,
                     output_plot_fcg_nfcgpermol, output_plot_fcg_nfcgpermol_zoom,
                     ]


def _remove_existing_files(paths):
    """Delete every file of *paths* that currently exists on disk."""
    for path in map(Path, paths):
        if path.exists():
            path.unlink()


# --regenplots only drops the plots, --clear drops both plots and CSV files
if args.regenplots:
    logger.info('REMOVING EXISTING OUTPUT PLOT FILES...')
    _remove_existing_files(output_plot_files)

if args.clear:
    logger.info('REMOVING EXISTING OUTPUT PLOT AND CSV FILES...')
    _remove_existing_files(output_plot_files + output_csv_files)

# Three cases: every plot already exists (nothing to do), every CSV exists (reload them),
# or the CSV files are computed from the chunks and saved.
if all(Path(x).exists() for x in output_plot_files):
    logger.info("FCG -- ALL OUTPUT PLOT FILES ARE ALREADY AVAILABLE, NOTHING TO DO!")
elif all(Path(x).exists() for x in output_csv_files):
    logger.info("FCG -- PARSING OUTPUT CSV FILES INSTEAD OF COMPUTING THEM")
    df_fcg_nfragpermol = load.file(output_csv_fcg_nfragpermol)
    df_fcg_nfragpermol_u = load.file(output_csv_fcg_nfragpermol_u)
    df_fcg_fragmolcov = load.file(output_csv_fcg_fragmolcov)
    # only the 10 first rows are used for the "top 10" plots
    df_fcg_top10frags = load.file(output_csv_fcg_top10frags, decode=False).head(10)
    df_fcg_top10frags_u = load.file(output_csv_fcg_top10frags_u, decode=False).head(10)
    df_fcg_fcc = load.file(output_csv_fcg_fcc)
    df_fcg_top10fc = load.file(output_csv_fcg_top10fc, decode=False).head(10)  # mol_frag_1/2 but actually smiles
    df_fcg_nfcgpermol = load.file(output_csv_fcg_nfcgpermol)
else:
    logger.info("FCG -- COMPUTING OUTPUT CSV FILES")
    dfs_fcg = get_df_fcg(args.wd, subset=subset)
    for df_fcg, output_csv_fm in zip(dfs_fcg, output_csv_files):
        save.file(df_fcg, output_csv_fm, encode=False)
    (df_fcg_nfragpermol, df_fcg_nfragpermol_u, df_fcg_fragmolcov,
     df_fcg_top10frags, df_fcg_top10frags_u, df_fcg_fcc,
     df_fcg_top10fc, df_fcg_nfcgpermol) = dfs_fcg[:8]
    # only the 10 first rows are used for the "top 10" plots
    df_fcg_top10frags = df_fcg_top10frags.head(10)
    df_fcg_top10frags_u = df_fcg_top10frags_u.head(10)
    df_fcg_top10fc = df_fcg_top10fc.head(10)
d5 = datetime.now()

# skip plots if computing only CSV output files
if not args.csv:

    # plot output_plot_fs_nfragpermol_u
    if Path(output_plot_fcg_nfragpermol).exists():
        logger.info("FCG -- OUTPUT_PLOT_FCG_NFRAGPERMOL".ljust(pad) + "ALREADY DONE")
    else:
        logger.info("FCG -- OUTPUT_PLOT_FCG_NFRAGPERMOL".ljust(pad) + "COMPUTING...")
        report.save_distplot(df_fcg_nfragpermol,
                             output_plot_fcg_nfragpermol,
                             'NumFrags',
                             f"FCG - Number of Fragment Hits Per Molecule in {dataset}",
                             color=color,
                             x_label='Number of Fragment Hits Per Molecule',
                             y_label='Count',
                             fig_size=(24, 12),
                             )

    # plot output_plot_fcg_nfragpermol_zoom
    if Path(output_plot_fcg_nfragpermol_zoom).exists():
        logger.info("FCG -- OUTPUT_PLOT_FCG_NFRAGPERMOL_ZOOM".ljust(pad) + "ALREADY DONE")
    else:
        logger.info("FCG -- OUTPUT_PLOT_FCG_NFRAGPERMOL_ZOOM".ljust(pad) + "COMPUTING...")
        report.save_barplot(df_fcg_nfragpermol.head(20),
                            output_plot_fcg_nfragpermol_zoom,
                            'NumFrags',
                            'Count',
                            f"FCG - Number of Fragment Hits Per Molecule in {dataset} (zoom)",
                            x_label='Number of Fragment Hits Per Molecule',
                            y_label='Count',
                            color=color,
                            perc_labels='Perc_Mols',
                            )

    # plot output_plot_fcg_fragmolcov
    if Path(output_plot_fcg_fragmolcov).exists():
        logger.info("FCG -- OUTPUT_PLOT_FCG_FRAGMOLCOV".ljust(pad) + "ALREADY DONE")
    else:
        logger.info("FCG -- OUTPUT_PLOT_FCG_FRAGMOLCOV".ljust(pad) + "COMPUTING...")
        report.save_kdeplot(df_fcg_fragmolcov,
                            output_plot_fcg_fragmolcov,
                            x_name='frag_ratio',
                            title=f"FCG - Distribution of Molecule Coverage by Fragments in {dataset}",
                            x_label='Molecule Coverage by Fragments',
                            y_label='Kernel Density Estimate of the Number of Molecules',
                            color=color,
                            )

    # plot output_plot_fcg_fragmolcov_wo_side_chain
    if Path(output_plot_fcg_fragmolcov_wo_side_chain).exists():
        logger.info("FCG -- OUTPUT_PLOT_FCG_FRAGMOLCOV_WO_SIDE_CHAIN".ljust(pad) + "ALREADY DONE")
    else:
        logger.info("FCG -- OUTPUT_PLOT_FCG_FRAGMOLCOV_WO_SIDE_CHAIN".ljust(pad) + "COMPUTING...")
        report.save_kdeplot(df_fcg_fragmolcov,
                            output_plot_fcg_fragmolcov_wo_side_chain,
                            x_name='frag_ratio_wo_side_chain',
                            title=f"FCG - Distribution of Molecule Coverage by Fragments in {dataset}",
                            x_label='Molecule Coverage by Fragments (without side chains)',
                            y_label='Kernel Density Estimate of the Number of Molecules',
                            color=color,
                            )

    # plot output_plot_fcg_top10frags
    if Path(output_plot_fcg_top10frags).exists():
        logger.info("FCG -- OUTPUT_PLOT_FCG_TOP10FRAGS".ljust(pad) + "ALREADY DONE")
    else:
        logger.info("FCG -- OUTPUT_PLOT_FCG_TOP10FRAGS".ljust(pad) + "COMPUTING...")
        report.save_barplot(df_fcg_top10frags,
                            output_plot_fcg_top10frags,
                            'idf',
                            'Count',
                            f"FCG - Top 10 Fragments by Occurrence in {dataset}",
                            x_label='Fragment ID',
                            y_label='Count',
                            color=color,
                            rotate_x=45,
                            perc_labels='Perc_FHits',
                            force_order=True,
                            fig_size=(12, 12),
                            )

    # plot output_plot_fs_nfragpermol_u
    if Path(output_plot_fcg_nfragpermol_u).exists():
        logger.info("FCG -- OUTPUT_PLOT_FCG_NFRAGPERMOL_U".ljust(pad) + "ALREADY DONE")
    else:
        logger.info("FCG -- OUTPUT_PLOT_FCG_NFRAGPERMOL_U".ljust(pad) + "COMPUTING...")
        report.save_distplot(df_fcg_nfragpermol_u,
                             output_plot_fcg_nfragpermol_u,
                             'NumFrags',
                             f"FCG - Number of Unique Fragment Hits Per Molecule in {dataset}",
                             color=color,
                             x_label='Number of Unique Fragment Hits Per Molecule',
                             y_label='Count',
                             fig_size=(24, 12),
                             )

    # plot output_plot_fcg_nfragpermol_u_zoom
    if Path(output_plot_fcg_nfragpermol_u_zoom).exists():
        logger.info("FCG -- OUTPUT_PLOT_FCG_NFRAGPERMOL_U_ZOOM".ljust(pad) + "ALREADY DONE")
    else:
        logger.info("FCG -- OUTPUT_PLOT_FCG_NFRAGPERMOL_U_ZOOM".ljust(pad) + "COMPUTING...")
        report.save_barplot(df_fcg_nfragpermol_u.head(20),
                            output_plot_fcg_nfragpermol_u_zoom,
                            'NumFrags',
                            'Count',
                            f"FCG - Number of Unique Fragment Hits Per Molecule in {dataset} (zoom)",
                            x_label='Number of Unique Fragment Hits Per Molecule',
                            y_label='Count',
                            color=color,
                            perc_labels='Perc_Mols',
                            )

    # plot output_plot_fcg_top10frags_u
    if Path(output_plot_fcg_top10frags_u).exists():
        logger.info("FCG -- OUTPUT_PLOT_FCG_TOP10FRAGS_U".ljust(pad) + "ALREADY DONE")
    else:
        logger.info("FCG -- OUTPUT_PLOT_FCG_TOP10FRAGS_U".ljust(pad) + "COMPUTING...")
        report.save_barplot(df_fcg_top10frags_u,
                            output_plot_fcg_top10frags_u,
                            'idf',
                            'Count',
                            f"FCG - Top 10 Unique Fragments by Occurrence in {dataset}",
                            x_label='Fragment ID',
                            y_label='Count',
                            color=color,
                            rotate_x=45,
                            perc_labels='Perc_FHits',
                            force_order=True,
                            fig_size=(12, 12),
                            )

    # plot output_plot_fcg_fcc
    if Path(output_plot_fcg_fcc).exists():
        logger.info("FCG -- OUTPUT_PLOT_FCG_FCC".ljust(pad) + "ALREADY DONE")
    else:
        logger.info("FCG -- OUTPUT_PLOT_FCG_FCC".ljust(pad) + "COMPUTING...")
        report.save_barplot(df_fcg_fcc,
                            output_plot_fcg_fcc,
                            'fcc',
                            'Count',
                            f"FCG - Fragment Combination Classification in {dataset}",
                            x_label='Fragment Combination Categories',
                            y_label='Count',
                            color=color,
                            perc_labels='Perc',
                            )

    # plot output_plot_fcg_top10fc
    if Path(output_plot_fcg_top10fc).exists():
        logger.info("FCG -- OUTPUT_PLOT_FC_COUNTS".ljust(pad) + "ALREADY DONE")
    else:
        logger.info("FCG -- OUTPUT_PLOT_FC_COUNTS".ljust(pad) + "COMPUTING...")  # for that one percentage labelling go crazy...!
        df_fcg_top10fc['fc'] = df_fcg_top10fc['fc'].map(lambda x: x.replace('[', '\n').replace(']', '\n'))
        report.save_barplot(df_fcg_top10fc,
                            output_plot_fcg_top10fc,
                            'fc',
                            'Count',
                            f"FCG - Top 10 Fragment Combinations by Occurence in {dataset}",
                            x_label='Fragment Combinations',
                            y_label='Count',
                            color=color,
                            rotate_x=0,
                            perc_labels='Perc',
                            fig_size=(28, 12)
                            )

    # plot output_plot_fcg_nfcgpermol
    if Path(output_plot_fcg_nfcgpermol).exists():
        logger.info("FCG -- OUTPUT_PLOT_FCG_NFRAGPERMOL_U".ljust(pad) + "ALREADY DONE")
    else:
        logger.info("FCG -- OUTPUT_PLOT_FCG_NFRAGPERMOL_U".ljust(pad) + "COMPUTING...")
        report.save_distplot(df_fcg_nfcgpermol,
                             output_plot_fcg_nfcgpermol,
                             'NumFCG',
                             f"FCG - Number of Fragment Graphs Per Molecule in {dataset}",
                             color=color,
                             x_label='Number of Fragment Graphs Per Molecule',
                             y_label='Count',
                             fig_size=(24, 12),
                             )

    # plot output_plot_fcg_nfcgpermol_zoom
    if Path(output_plot_fcg_nfcgpermol_zoom).exists():
        logger.info("FCG -- OUTPUT_PLOT_FCG_NFCGPERMOL_ZOOM".ljust(pad) + "ALREADY DONE")
    else:
        logger.info("FCG -- OUTPUT_PLOT_FCG_NFCGPERMOL_ZOOM".ljust(pad) + "COMPUTING...")
        report.save_barplot(df_fcg_nfcgpermol,
                            output_plot_fcg_nfcgpermol_zoom,
                            'NumFCG',
                            'Count',
                            f"FCG - Number of Fragment Graphs Per Molecule in {dataset} (zoom)",
                            x_label='Number of Fragment Graphs Per Molecule',
                            y_label='Count',
                            color=color,
                            perc_labels='Perc_Mols',
                            )
# Closing summary: log the end of the report and the total wall-clock time since d0.
d6 = datetime.now()
for closing_message in ("-- END OF REPORT", "-- COMPUTATIONAL TIME"):
    logger.info(closing_message)
# Per-stage timings (d1 .. d5) are currently not reported; only the overall total is.
total_elapsed = d6 - d0
logger.info("TOTAL:".ljust(pad) + f"{total_elapsed}")
