#!/usr/bin/env python

"""
Script report_subset
====================
This script is used for reporting the SUBSET step.
"""

# standard
import warnings
import logging
import argparse
import sys
from shutil import rmtree
from datetime import datetime
import re
from pathlib import Path
from collections import OrderedDict
# data handling
import numpy as np
import json
import pandas as pd
from pandas import DataFrame
import networkx as nx
# data visualization
from matplotlib import pyplot as plt
import seaborn as sns
import matplotlib.ticker as mticker
# from pylab import savefig
from adjustText import adjust_text
# chemoinformatics
import rdkit
from rdkit import Chem
from rdkit.Chem import Mol
# docs
from typing import List
from typing import Tuple
from rdkit import RDLogger
# custom libraries
import npfc
from npfc import utils
from npfc import load
from npfc import save
from npfc import draw
from npfc import report
from multiprocessing import Pool
# Disable pandas' SettingWithCopyWarning for the whole script:
# setting the option to None turns the chained-assignment check off.
pd.options.mode.chained_assignment = None  # default='warn'


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FUNCTIONS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


def parse_chunk_subset(c):
    """Extract the passed/filtered record counts from the log file of one chunk.

    The log is scanned line by line for the first line containing
    ``NUMBER OF REMAINING RECORDS IN SUBSET``. The last whitespace-separated
    token of that line is expected to look like ``<passed>/<total>``.

    The file is read as plain text instead of going through a CSV parser, so
    separator or quote characters in other log lines cannot break the parsing.

    :param c: the path of the log file to parse
    :return: a tuple (passed, filtered), where filtered = total - passed
    :raises ValueError: if the summary line is missing or malformed
    """
    marker = "NUMBER OF REMAINING RECORDS IN SUBSET"
    with open(c, encoding="utf-8") as fh:
        # stop reading as soon as the first summary line is found
        summary = next((line for line in fh if marker in line), None)
    if summary is None:
        raise ValueError(f"NO '{marker}' LINE FOUND IN LOG FILE '{c}'")

    # the counts are stored in the last token of the line as 'passed/total'
    records = summary.split()[-1].split('/')
    try:
        passed = int(records[0])
        total = int(records[1])
    except (IndexError, ValueError) as err:
        raise ValueError(f"MALFORMED '{marker}' LINE IN LOG FILE '{c}': {summary.strip()!r}") from err
    filtered = total - passed
    return (passed, filtered)


def get_df_subset(WD: Path) -> DataFrame:
    """Get a DF summarizing the results of the subset step.

    The log files found in WD are parsed in parallel with parse_chunk_subset
    and the per-chunk counts are summed up into a single summary.

    This function uses the module-level ``logger``, which is defined in main.

    :param WD: the directory where the log files of the subset step are stored
    :return: a DF with one row per category ('passed', 'filtered') and the columns 'Category', 'Count' and 'Perc_Status'
    """
    logger.info("SUB -- COMPUTING RESULTS FOR SUBSET")
    if not isinstance(WD, Path):
        WD = Path(WD)
    # identify the log files to parse
    # NOTE(review): the dot before 'log' is not escaped; left untouched because
    # the way report._get_chunks interprets the pattern is not visible here.
    pattern = ".*([0-9]{3})?.log"
    chunks = report._get_chunks(f"{WD}", pattern)

    # the context manager releases the worker processes even if parsing a chunk fails;
    # map blocks until all results are collected, so nothing is lost on exit
    with Pool() as pool:
        results = pool.map(parse_chunk_subset, chunks)

    # sum of tuples
    df = pd.DataFrame(results, columns=['passed', 'filtered'])
    num_passed = df.passed.sum()
    num_filtered = df.filtered.sum()
    num_total = num_passed + num_filtered

    # create a dataframe with counts
    df = pd.DataFrame({'Category': ['passed', 'filtered'], 'Count': [num_passed, num_filtered]})
    if num_total == 0:
        # no chunk or only empty chunks: avoid dividing by zero
        logger.warning("SUB -- NO RECORD FOUND IN THE LOG FILES, PERCENTAGES ARE SET TO 0.00%")
        df['Perc_Status'] = f"{0:.2%}"
    else:
        df['Perc_Status'] = df['Count'].map(lambda x: f"{x/num_total:.2%}")
    logger.info(f"SUB -- RESULTS FOR SUBSETTING {num_total:,} MOLECULES:\n\n{df}\n")

    return df


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BEGIN ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


class _ReportHelpFormatter(argparse.ArgumentDefaultsHelpFormatter, argparse.RawDescriptionHelpFormatter):
    """Help formatter displaying argument defaults while keeping the layout of the description."""


def _parse_bool_flag(value) -> bool:
    """Convert the string value of a CLI flag into a boolean.

    Explicit negative values ('', '0', 'false', 'f', 'no', 'n', 'off', 'none';
    case-insensitive) return False, any other value returns True.
    """
    if isinstance(value, bool):
        return value
    return str(value).strip().lower() not in ('', '0', 'false', 'f', 'no', 'n', 'off', 'none')


def main():
    """Run the reporting of the SUBSET step from the command line.

    Steps:
        1. parse and check the CLI arguments
        2. configure logging (the module-level ``logger`` is defined here, get_df_subset relies on it)
        3. summarize the logs found in the input folder with get_df_subset
        4. export the summary as a CSV file and, unless --csv is set, as a bar plot

    The --clear and --regenplots arguments are parsed and logged, but not used otherwise in this script.

    :raises ValueError: if the plot format is neither 'svg' nor 'png'
    """
    global logger  # shared with get_df_subset
    from textwrap import dedent  # only needed for rendering the help message

    # init
    description = dedent("""\
        Script for reporting the SUBSET step, which consists in
        filtering out natural products from the dataset to achieve a synthetic subset.

        This script takes two inputs:
            - the log folder where SUBSET results are stored
            - the output folder where to save the report (log, csv and plot files)

        Some parameters can be used for adjusting the reporting, type report_subset -h
        to display them.

        This script uses the installed npfc library in your favorite env manager.

        Example:

            >>> report_subset fc/04_synthetic/chembl/data/prep/natref_dnp/06_subset/data fc/04_synthetic/chembl/data/prep/natref_dnp/06_subset/report -d "My Data Set" -p dataset -c blue
        """)

    # parameters CLI
    parser = argparse.ArgumentParser(description=description, formatter_class=_ReportHelpFormatter)
    parser.add_argument('wd', type=str, default=None, help="Working directory where the logs to parse are")
    parser.add_argument('wd_out', type=str, default=None, help="Output directory")
    parser.add_argument('-d', '--dataset', type=str, default=None, help="Dataset name for using in the csv/png outputs in the report folder.")
    parser.add_argument('-p', '--prefix', type=str, default=None, help="Prefix used for output files in the data/log folders.")
    parser.add_argument('-c', '--color', type=str, default='black', help="Color to use for plots.")
    parser.add_argument('--plotformat', type=str, default='svg', help="Format to use for plots. Possible values are 'svg' and 'png'.")
    # boolean flags still expect a value (i.e. '--csv True'), but 'False', '0', 'no'... now disable them
    parser.add_argument('--csv', type=_parse_bool_flag, default=False, help="Generate only CSV output files")
    parser.add_argument('--clear', type=_parse_bool_flag, default=False, help="Force the generation of log, plot and CSV files by clearing all report files at any found specified levels.")
    parser.add_argument('--regenplots', type=_parse_bool_flag, default=False, help="Force the generation of plots by clearing any pre-existing plot at any specified levels.")
    parser.add_argument('--log', type=str, default='INFO', help="Specify level of logging. Possible values are: CRITICAL, ERROR, WARNING, INFO, DEBUG.")
    args = parser.parse_args()

    # check arguments

    # I/O
    utils.check_arg_input_dir(args.wd)
    utils.check_arg_output_dir(args.wd_out)
    wd = Path(args.wd)
    wd_out = Path(args.wd_out)

    # prefix
    if args.prefix is None:
        logging.warning("PREFIX IS NOT SET, RESORTING TO WD DIRNAME.")
        prefix = Path(args.wd).name
    else:
        prefix = args.prefix

    # dataset
    if args.dataset is None:
        logging.warning("DATASET IS NOT SET, USING PREFIX INSTEAD.")
        dataset = prefix
    else:
        dataset = args.dataset

    # plotformat
    if args.plotformat not in ('svg', 'png'):
        raise ValueError(f"ERROR! UNKNOWN PLOT FORMAT! ('{args.plotformat}')")

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ INIT ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #

    # logging
    pd.options.mode.chained_assignment = None  # disable pd.io.pytables.SettingWithCopyWarning
    warnings.filterwarnings('ignore', category=pd.io.pytables.PerformanceWarning)  # if None is returned instead of a molecule, do not complain about mixed types
    lg = RDLogger.logger()
    lg.setLevel(RDLogger.INFO)
    log_file = f"{args.wd_out}/report_subset_{prefix}.log"
    utils.check_arg_output_file(log_file)
    logger = utils._configure_logger(log_level=args.log, log_file=log_file, logger_name=log_file)

    # display rendering
    report.init_report_globals()
    pad_title = 80
    pad = 60
    # named colors of the default palette are translated, any other value is used as is
    color = report.DEFAULT_PALETTE.get(args.color, args.color)

    # display infos
    logger.info("LIBRARY VERSIONS")
    logger.info("rdkit".ljust(pad) + f"{rdkit.__version__}")
    logger.info("pandas".ljust(pad) + f"{pd.__version__}")
    logger.info("npfc".ljust(pad) + f"{npfc.__version__}")
    logger.info("ARGUMENTS")
    logger.info("WD_in".ljust(pad) + f"{args.wd}")
    logger.info("WD_out".ljust(pad) + f"{args.wd_out}")
    logger.info("log_file".ljust(pad) + f"{log_file}")
    logger.info("color".ljust(pad) + f"{color}")
    logger.info("dataset".ljust(pad) + f"{dataset}")
    logger.info("prefix".ljust(pad) + f"{prefix}")
    logger.info("clear".ljust(pad) + f"{args.clear}")
    logger.info("regenplots".ljust(pad) + f"{args.regenplots}")
    logger.info("plotformat".ljust(pad) + f"{args.plotformat}")
    logger.info("log".ljust(pad) + f"{args.log}")
    # reference time for all timings in the summary (argument parsing and init are not included)
    d0 = datetime.now()

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BEGIN ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #

    report.print_title("REPORTING SUBSET", 3, pad_title)

    # parse wd
    df_sub_subset = get_df_subset(wd)
    d1 = datetime.now()

    # define csv outputs
    # make sure the data subfolder exists, as done for the plot subfolder below
    # (cannot tell from here whether save.file creates it on its own)
    Path(f"{wd_out}/data").mkdir(parents=True, exist_ok=True)
    output_csv_subset_subset = f"{wd_out}/data/{prefix}_subset_subset.csv"
    save.file(df_sub_subset, output_csv_subset_subset)
    logger.info("SUBSET -- OUTPUT_CSV_SUBSET_SUBSET".ljust(pad) + f"{output_csv_subset_subset}")
    d2 = datetime.now()

    if not args.csv:
        Path(f"{args.wd_out}/plot").mkdir(parents=True, exist_ok=True)

        # define plot outputs
        output_plot_subset_subset = f"{wd_out}/plot/{prefix}_subset_subset.{args.plotformat}"
        logger.info("SUBSET -- OUTPUT_PLOT_SUBSET_SUBSET".ljust(pad) + f"{output_plot_subset_subset}")

        # plot output_plot_subset_subset
        logger.info("SUBSET -- OUTPUT_PLOT_SUBSET_SUBSET".ljust(pad) + "COMPUTING...")
        report.save_barplot(df_sub_subset,
                            output_plot_subset_subset,
                            'Category',
                            'Count',
                            f"Creating a Synthetic Subset of {dataset}",
                            x_label='Category',
                            y_label='Count',
                            color=color,
                            perc_labels='Perc_Status',
                            fig_size=(12, 12),
                            )
    # when only CSV files are requested, d3-d2 is close to zero
    d3 = datetime.now()

    # end
    logger.info("SUMMARY")
    logger.info("COMPUTATIONAL TIME: ITERATING OVER CHUNKS".ljust(pad * 2) + f"{d1-d0}")
    logger.info("COMPUTATIONAL TIME: EXPORTING CSV".ljust(pad * 2) + f"{d2-d1}")
    logger.info("COMPUTATIONAL TIME: EXPORTING PLOTS".ljust(pad * 2) + f"{d3-d2}")
    logger.info("COMPUTATIONAL TIME: TOTAL".ljust(pad * 2) + f"{d3-d0}")
    logger.info("END")


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MAIN ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


if __name__ == '__main__':
    # run the report only when executed as a script, then
    # terminate explicitly with a success exit status
    main()
    raise SystemExit(0)
