#!/usr/bin/env python

"""
Script report_pnp
==========================
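This script counts PNPs and NPLs from the PNP annotation output chunks and reports
their ratio as well as the number of reference NP-FCGs per NPL (log, CSV and plot files).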
"""

# standard
import warnings
from pathlib import Path
import logging
import re
import sys
from datetime import timedelta
from datetime import datetime
import pandas as pd
import argparse
from pandas import DataFrame
# chemoinformatics
import rdkit
# dev
import npfc
from npfc import load
from npfc import save
from npfc import utils
from npfc import report
import subprocess
from multiprocessing import Pool
# Silence pandas' SettingWithCopyWarning for the whole script (pandas' own default is 'warn').
pd.set_option('mode.chained_assignment', None)


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FUNCTIONS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #

# todo:
# 1. iterate over each chunk
# 2. count PNPs, NPLs
# 3. draw n examples per chunk
# 4. plot results


def parse_chunk_pnp(c):
    """Count the PNPs of one chunk.

    A PNP is counted once per distinct 'idm' value, i.e. duplicated 'idm' rows
    are collapsed before counting.

    :param c: a PNP chunk (as matched by get_dfs_pnp, e.g. '*_pnp.csv.gz'), loaded without decoding
    :return: the number of rows left after dropping duplicated 'idm' values
    """
    chunk = load.file(c, decode=False)
    unique_pnps = chunk.drop_duplicates(subset=['idm'])
    return len(unique_pnps)


def parse_chunk_npl(c):
    """Count the NPLs of one chunk and how many distinct reference NP-FCGs each one has.

    :param c: an NPL chunk (as matched by get_dfs_pnp, e.g. '*_npl.csv.gz'); its '_pnp_ref'
              column is requested to be decoded by load.file
    :return: a dict with:
        - 'num_npl': the number of distinct 'idm' values in the chunk
        - 'df_npl_numnpfcgrefpernpl': a DF with columns 'NumRefNPFCGPerNPL' (number of
          distinct references of an NPL) and 'Count' (number of NPLs with that many references)
    """
    chunk = load.file(c, decode=['_pnp_ref'])

    # one row per NPL ('idm'): all '_pnp_ref' entries of that NPL are aggregated with 'sum'
    # (presumably lists after decoding, so they get concatenated -- TODO confirm)
    per_npl = (chunk.groupby('idm')
                    .agg({'_pnp_ref': 'sum'})
                    .rename({'_pnp_ref': 'NumRefNPFCGPerNPL'}, axis=1)
                    .reset_index())
    # ... and reduced to the number of distinct references
    per_npl['NumRefNPFCGPerNPL'] = per_npl['NumRefNPFCGPerNPL'].map(lambda refs: len(set(refs)))

    # histogram: how many NPLs share each number of distinct references
    histogram = (per_npl.groupby('NumRefNPFCGPerNPL')
                        .count()
                        .rename({'idm': 'Count'}, axis=1)
                        .reset_index())

    return {'num_npl': len(per_npl), 'df_npl_numnpfcgrefpernpl': histogram}


def _format_perc(count, total) -> str:
    """Format count / total as a percentage string; '0.00%' when total is 0 (avoids ZeroDivisionError)."""
    return f"{count / total:.2%}" if total else f"{0:.2%}"


def get_dfs_pnp(WD: Path) -> dict:
    """Get the DFs summarizing the results of the PNP annotation step.

    Output chunks for PNPs ('*_pnp.csv.gz') and NPLs ('*_npl.csv.gz') need to be in
    the same directory. Chunks are parsed in parallel with a multiprocessing pool.
    An empty directory yields zero counts instead of a crash.

    :param WD: the directory containing the PNP/NPL output chunks (str or Path)
    :return: a dict with two DFs:
        - 'df_pnp_ratio': columns 'Category' (PNP/NPL), 'Count' and 'Perc'
        - 'df_npl_numnpfcgrefpernpl': columns 'NumRefNPFCGPerNPL', 'Count' and 'Perc'
    """
    logger.info("PNP -- COMPUTING RESULTS FOR PNPANNOTATION - PNPs")
    WD = Path(WD)

    pattern_pnp = ".*([0-9]{3})_pnp.csv.gz"
    chunks_pnp = report._get_chunks(f"{WD}", pattern_pnp)
    pattern_npl = ".*([0-9]{3})_npl.csv.gz"
    chunks_npl = report._get_chunks(f"{WD}", pattern_npl)

    # parse PNP and NPL chunks with a single pool; the context manager releases the
    # worker processes even if parsing a chunk raises
    with Pool() as pool:
        results_pnp = pool.map(parse_chunk_pnp, chunks_pnp)
        results_npl = pool.map(parse_chunk_npl, chunks_npl)

    # sum of counts
    num_pnp = sum(results_pnp)
    num_npl = sum(x['num_npl'] for x in results_npl)
    num_total = num_pnp + num_npl
    logger.info(f"pnp: {num_pnp:,} + npl: {num_npl:,} = tot: {num_total:,}")
    if num_total == 0:
        logger.warning(f"PNP -- NO MOLECULES FOUND IN PNP/NPL CHUNKS OF '{WD}'")

    # create a dataframe with counts
    df_pnp_ratio = pd.DataFrame({'Category': ['PNP', 'NPL'], 'Count': [num_pnp, num_npl]})
    df_pnp_ratio['Perc'] = df_pnp_ratio['Count'].map(lambda x: _format_perc(x, num_total))
    logger.info(f"PNP -- RESULTS FOR LABELLING PNPs IN {num_total:,} MOLECULES:\n\n{df_pnp_ratio}\n")

    # Number of NP refs per NPL, merged over all chunks
    dfs_npl_numnpfcgrefpernpl = [x['df_npl_numnpfcgrefpernpl'] for x in results_npl]
    if dfs_npl_numnpfcgrefpernpl:
        df_npl_numnpfcgrefpernpl = (pd.concat(dfs_npl_numnpfcgrefpernpl)[['NumRefNPFCGPerNPL', 'Count']]
                                    .groupby('NumRefNPFCGPerNPL')
                                    .sum()
                                    .reset_index())
    else:
        # pd.concat raises on an empty list, so build the empty summary explicitly
        df_npl_numnpfcgrefpernpl = pd.DataFrame(columns=['NumRefNPFCGPerNPL', 'Count'])
    tot = df_npl_numnpfcgrefpernpl['Count'].sum()
    df_npl_numnpfcgrefpernpl['Perc'] = df_npl_numnpfcgrefpernpl['Count'].map(lambda x: _format_perc(x, tot))
    logger.info(f"PNP -- RESULTS NUMBER OF NP REFERENCES PER FCG:\n\n{df_npl_numnpfcgrefpernpl}\n")

    return {'df_pnp_ratio': df_pnp_ratio, 'df_npl_numnpfcgrefpernpl': df_npl_numnpfcgrefpernpl}


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BEGIN ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


def _str2bool(value) -> bool:
    """Convert a CLI flag value such as 'True', 'false', 'yes' or '0' into a bool.

    :param value: the raw command-line value (bools are returned unchanged)
    :return: the corresponding boolean
    :raise argparse.ArgumentTypeError: if the value is not a recognized boolean
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    if normalized in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got '{value}'")


class _RawDescriptionDefaultsFormatter(argparse.ArgumentDefaultsHelpFormatter, argparse.RawDescriptionHelpFormatter):
    """Help formatter that keeps the description layout as written and also shows argument defaults."""


def main():
    """Parse the CLI arguments, then compute, export (CSV) and plot the PNP annotation report."""

    # init
    description = """Script for reporting the PNP annotation step.

    This script takes two inputs:
        - the folder with the PNP and NPL output chunks (*_pnp.csv.gz and *_npl.csv.gz)*
        - the output folder where to save the report (log, csv and plot files)

        fc/
        ├── 03_natural/
        │   ├── coconut
                ├── data
                    ├── prep
                        ├── 01_chunk
                        ├── 02_load
                        ├── 03_standardize
                        ├── 04_deduplicate
                        ├── 05_depict
                        ├── natref_coconut
                            ├── 06_subset
                            ├── frags_crms
                                ├── 07_fs
                                ├── 08_fcc
                                ├── 09_fcg
                                ├── 10_pnp
                                    ├── data *


    Some parameters can be used for adjusting the reporting, type report_pnp -h
    to display them.

    This script uses the installed npfc library in your favorite env manager.

    Example:

        >>> report_pnp fc/04_synthetic/chembl/data/prep/ fc/04_synthetic/chembl/data/prep/report -d "My Data Set" -p dataset -c blue

    """

    # parameters CLI
    parser = argparse.ArgumentParser(description=description, formatter_class=_RawDescriptionDefaultsFormatter)
    parser.add_argument('wd', type=str, default=None, help="Working directory where the PNP and NPL output chunks (*_pnp.csv.gz, *_npl.csv.gz) are.")
    parser.add_argument('wd_out', type=str, default=None, help="Output directory")
    parser.add_argument('-d', '--dataset', type=str, default=None, help="Dataset name for using in the csv/png outputs in the report folder.")
    parser.add_argument('-p', '--prefix', type=str, default=None, help="Prefix used for output files in the data/log folders.")
    parser.add_argument('-c', '--color', type=str, default='black', help="Color to use for plots.")
    parser.add_argument('--plotformat', type=str, default='svg', help="Format to use for plots. Possible values are 'svg' and 'png'.")
    # boolean flags accept both '--csv' and '--csv True/False'; a plain type=str made
    # '--csv False' the truthy string "False"
    parser.add_argument('--csv', type=_str2bool, nargs='?', const=True, default=False, help="Generate only CSV output files")
    parser.add_argument('--clear', type=_str2bool, nargs='?', const=True, default=False, help="Force the generation of log, plot and CSV files by clearing all report files at any found specified levels.")
    parser.add_argument('--regenplots', type=_str2bool, nargs='?', const=True, default=False, help="Force the generation of plots by clearing any pre-existing plot at any specified levels.")
    parser.add_argument('--log', type=str, default='INFO', help="Specify level of logging. Possible values are: CRITICAL, ERROR, WARNING, INFO, DEBUG.")
    args = parser.parse_args()

    # check arguments

    # I/O
    utils.check_arg_input_dir(args.wd)
    utils.check_arg_output_dir(args.wd_out)
    wd = Path(args.wd)
    wd_out = Path(args.wd_out)

    # prefix
    if args.prefix is None:
        logging.warning("PREFIX IS NOT SET, RESORTING TO WD DIRNAME.")
        prefix = wd.name
    else:
        prefix = args.prefix

    # dataset
    if args.dataset is None:
        logging.warning("DATASET IS NOT SET, USING PREFIX INSTEAD.")
        dataset = prefix
    else:
        dataset = args.dataset

    # plotformat
    if args.plotformat not in ('svg', 'png'):
        raise ValueError(f"ERROR! UNKNOWN PLOT FORMAT! ('{args.plotformat}')")

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ INIT ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #

    # logging: get_dfs_pnp logs through this module-level logger
    global logger

    pd.options.mode.chained_assignment = None  # disable pandas SettingWithCopyWarning
    # silence pandas PerformanceWarning (e.g. pytables complaining about mixed types when None replaces a molecule)
    warnings.filterwarnings('ignore', category=pd.errors.PerformanceWarning)
    log_file = f"{args.wd_out}/report_pnpannotation_{prefix}.log"
    utils.check_arg_output_file(log_file)
    logger = utils._configure_logger(log_level=args.log, log_file=log_file, logger_name=log_file)

    # display rendering
    report.init_report_globals()
    pad_title = 80
    pad = 60
    color = report.DEFAULT_PALETTE.get(args.color, args.color)

    # display infos
    logger.info("LIBRARY VERSIONS")
    logger.info("rdkit".ljust(pad) + f"{rdkit.__version__}")
    logger.info("pandas".ljust(pad) + f"{pd.__version__}")
    logger.info("npfc".ljust(pad) + f"{npfc.__version__}")
    logger.info("ARGUMENTS")
    logger.info("WD_in".ljust(pad) + f"{args.wd}")
    logger.info("WD_out".ljust(pad) + f"{args.wd_out}")
    logger.info("log_file".ljust(pad) + f"{log_file}")
    logger.info("color".ljust(pad) + f"{color}")
    logger.info("dataset".ljust(pad) + f"{dataset}")
    logger.info("prefix".ljust(pad) + f"{prefix}")
    logger.info("clear".ljust(pad) + f"{args.clear}")
    logger.info("regenplots".ljust(pad) + f"{args.regenplots}")
    logger.info("plotformat".ljust(pad) + f"{args.plotformat}")
    logger.info("log".ljust(pad) + f"{args.log}")
    # timer starts after init so that the summary only measures the actual work
    d0 = datetime.now()

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BEGIN ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #

    report.print_title("REPORTING PNP ANNOTATION", 3, pad_title)

    # parse relevant prep subdirectories
    dfs_pnp = get_dfs_pnp(wd)
    df_pnp_ratio = dfs_pnp['df_pnp_ratio']
    df_npl_numnpfcgrefpernpl = dfs_pnp['df_npl_numnpfcgrefpernpl']
    d1 = datetime.now()

    # define csv outputs; make sure the data folder exists before writing into it
    Path(f"{wd_out}/data").mkdir(parents=True, exist_ok=True)
    output_csv_pnp_ratio = f"{wd_out}/data/{prefix}_pnp_ratio.csv"
    output_csv_pnp_numnpfcgrefpernpl = f"{wd_out}/data/{prefix}_pnp_numnpfcgrefpernpl.csv"
    save.file(df_pnp_ratio, output_csv_pnp_ratio)
    save.file(df_npl_numnpfcgrefpernpl, output_csv_pnp_numnpfcgrefpernpl)
    logger.info("PREP -- OUTPUT_CSV_PNP_RATIO".ljust(pad) + f"{output_csv_pnp_ratio}")
    logger.info("PREP -- OUTPUT_CSV_PNP_NUMNPFCGREFPERNPL".ljust(pad) + f"{output_csv_pnp_numnpfcgrefpernpl}")
    d2 = datetime.now()

    # define plot outputs
    output_plot_pnp_ratio = f"{wd_out}/plot/{prefix}_pnp_ratio.{args.plotformat}"
    output_plot_pnp_numnpfcgrefpernpl = f"{wd_out}/plot/{prefix}_pnp_numnpfcgrefpernpl.{args.plotformat}"
    logger.info("PREP -- OUTPUT PLOT FILE SYNTAX: OUTPUT_CSV.PLOTFORMAT")

    if not args.csv:
        Path(f"{wd_out}/plot").mkdir(parents=True, exist_ok=True)

        # plot output_plot_pnp_ratio
        logger.info("PNP -- OUTPUT_PLOT_PNP_RATIO".ljust(pad) + "COMPUTING...")
        report.save_barplot(df_pnp_ratio,
                            output_plot_pnp_ratio,
                            'Category',
                            'Count',
                            f"Ratio of PNP/NPL in {dataset}",
                            x_label='Category',
                            y_label='Count',
                            color=color,
                            perc_labels='Perc',
                            fig_size=(12, 12),
                            )

        # plot output_plot_pnp_numnpfcgrefpernpl
        logger.info("PREP -- OUTPUT_PLOT_PNP_NUMNPFCGREFPERNPL".ljust(pad) + "COMPUTING...")
        report.save_barplot(df_npl_numnpfcgrefpernpl,
                            output_plot_pnp_numnpfcgrefpernpl,
                            'NumRefNPFCGPerNPL',
                            'Count',
                            f"Number of Reference NP-FCG per NPL in {dataset}",
                            x_label='Category',
                            y_label='Count',
                            color=color,
                            perc_labels='Perc',
                            )

    d3 = datetime.now()

    # end
    logger.info("SUMMARY")
    logger.info("COMPUTATIONAL TIME: ITERATING OVER CHUNKS".ljust(pad * 2) + f"{d1-d0}")
    logger.info("COMPUTATIONAL TIME: EXPORTING CSV".ljust(pad * 2) + f"{d2-d1}")
    logger.info("COMPUTATIONAL TIME: EXPORTING PLOTS".ljust(pad * 2) + f"{d3-d2}")
    logger.info("COMPUTATIONAL TIME: TOTAL".ljust(pad * 2) + f"{d3-d0}")
    logger.info("END")


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MAIN ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


if __name__ == '__main__':
    # run the report, then hand an explicit success status back to the shell
    main()
    raise SystemExit(0)
