#!/usr/bin/env python

"""
Script report_fcp
=================
This script is used for reporting the FCP step.
"""

# standard
import warnings
import logging
import argparse
import sys
from shutil import rmtree
from datetime import datetime
import re
from pathlib import Path
from collections import OrderedDict
# data handling
import numpy as np
import json
import pandas as pd
from pandas import DataFrame
import networkx as nx
# data visualization
from matplotlib import pyplot as plt
import seaborn as sns
import matplotlib.ticker as mticker
# from pylab import savefig
from adjustText import adjust_text
# chemoinformatics
import rdkit
from rdkit import Chem
from rdkit.Chem import Mol
# docs
from typing import List
from typing import Tuple
from rdkit import RDLogger
# custom libraries
import npfc
from npfc import utils
from npfc import load
from npfc import save
from npfc import draw
from npfc import report
from multiprocessing import Pool
# silence pandas' SettingWithCopyWarning for the whole script (pandas default is 'warn')
pd.set_option('mode.chained_assignment', None)


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FUNCTIONS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


def parse_chunk_fcp(c):
    """Count the entries of one FCP chunk for each number of symmetry groups.

    :param c: the chunk file, loaded with load.file (decode=False)
    :return: a DF with the columns NumSymGroups and Count, one row per distinct number of symmetry groups
    """
    kept_columns = ['idm', 'num_symmetry_groups']
    renamed_columns = {'idm': 'Count', 'num_symmetry_groups': 'NumSymGroups'}

    # only the identifier and the grouping column are needed for counting
    df_chunk = load.file(c, decode=False)[kept_columns]

    # count() tallies the non-null 'idm' values found in each group
    per_group = df_chunk.groupby('num_symmetry_groups').count()

    return per_group.reset_index().rename(renamed_columns, axis=1)


def get_df_fcp(WD: str) -> DataFrame:
    """Get a DF summarizing the results of the fcp step.

    The chunks found in WD are parsed in parallel with parse_chunk_fcp, then the
    per-chunk counts are summed for each number of symmetry groups.

    This function logs through the module-level ``logger``, which is set up by main().

    :param WD: the directory of the fcp step, where the chunks to parse are
    :return: a DF with the columns NumSymGroups, Count and Perc_Mols (empty if no chunk yields any row)
    """
    logger.info("FCP -- COMPUTING FCP RESULTS")
    # NOTE(review): the dots in this pattern are unescaped, so they match any character
    # if report._get_chunks treats it as a regular expression -- confirm before tightening it
    pattern = ".*([0-9]{3})?_fcp.csv.gz"
    chunks = report._get_chunks(f"{WD}", pattern)
    logger.info(f"FCP -- FOUND {len(chunks):,d} CHUNKS")

    # parse all files in parallel; the context manager releases the worker
    # processes even when parsing one of the chunks raises
    with Pool() as pool:
        dfs = pool.map(parse_chunk_fcp, chunks)

    # chunks without any row do not contribute to the summary
    dfs = [df for df in dfs if len(df) > 0]

    # if no case was found, return an empty dataframe
    if not dfs:
        return pd.DataFrame([], columns=["NumSymGroups", "Count", "Perc_Mols"])

    # concatenate all dataframes and compute the sum of all counts
    df = pd.concat(dfs)
    df = df.groupby("NumSymGroups").sum().reset_index()[['NumSymGroups', 'Count']]
    tot_mols = df['Count'].sum()
    # percentages are stored as formatted strings (i.e. '12.34%')
    df['Perc_Mols'] = df['Count'].map(lambda x: f"{x/tot_mols:.2%}")
    logger.info(f"FCP -- RESULTS FOR FCP:\n\n{df}\n")

    return df


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BEGIN ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


def _str2bool(value):
    """Convert a command-line token into a boolean.

    Without this, argparse keeps the raw string, so that '--csv False' would
    result in the truthy value 'False'.

    :param value: the raw token (an actual bool is returned untouched)
    :return: True or False
    :raises argparse.ArgumentTypeError: if the token is not a recognized boolean spelling
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    if token in ('false', 'f', 'no', 'n', '0', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value (True/False), got '{value}'")


def main():
    """Parse the command line, summarize the FCP chunks and export the report.

    The summary is saved as {wd_out}/data/{prefix}_fcp_nsymgroups.csv and, unless
    --csv is set, plotted as {wd_out}/plot/{prefix}_fcp_nsymgroups.{plotformat}.
    A log file is written to {wd_out}/report_fcp_{prefix}.log.

    :raises ValueError: if the plot format is neither 'svg' nor 'png'
    """
    # get_df_fcp logs through this module-level name
    global logger

    # init
    description = """Script for reporting the FCP step, which consists in
    attributing Fragment Combination Point to fragments.

    This script takes two inputs:
        - the data folder where FCP results are stored
        - the output folder where to save the report (log, csv and plot files)

    Some parameters can be used for adjusting the reporting, type report_fcp -h
    to display them.

    This script uses the installed npfc library in your favorite env manager.

    Example:

        >>> report_fcp fc/04_synthetic/chembl/data/prep/05_fcp/data fc/04_synthetic/chembl/data/prep/05_fcp/report -d "My Data Set" -p dataset -c blue

    """

    # parameters CLI
    parser = argparse.ArgumentParser(description=description, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('wd', type=str, default=None, help="Working directory where the data to parse is")
    parser.add_argument('wd_out', type=str, default=None, help="Output directory")
    parser.add_argument('-d', '--dataset', type=str, default=None, help="Dataset name for using in the csv/png outputs in the report folder.")
    parser.add_argument('-p', '--prefix', type=str, default=None, help="Prefix used for output files in the data/log folders.")
    parser.add_argument('-c', '--color', type=str, default='black', help="Color to use for plots.")
    parser.add_argument('--plotformat', type=str, default='svg', help="Format to use for plots. Possible values are 'svg' and 'png'.")
    # boolean switches: '--csv', '--csv True' and '--csv False' are all accepted
    parser.add_argument('--csv', type=_str2bool, nargs='?', const=True, default=False, help="Generate only CSV output files")
    parser.add_argument('--clear', type=_str2bool, nargs='?', const=True, default=False, help="Force the generation of log, plot and CSV files by clearing all report files at any found specified levels.")
    parser.add_argument('--regenplots', type=_str2bool, nargs='?', const=True, default=False, help="Force the geeration of plots by clearing any pre-existing plot at any specified levels.")
    parser.add_argument('--log', type=str, default='INFO', help="Specify level of logging. Possible values are: CRITICAL, ERROR, WARNING, INFO, DEBUG.")
    args = parser.parse_args()

    # check arguments

    # I/O
    utils.check_arg_input_dir(args.wd)
    utils.check_arg_output_dir(args.wd_out)
    wd = Path(args.wd)
    wd_out = Path(args.wd_out)

    # prefix (the configured logger does not exist yet, hence the root logger)
    if args.prefix is None:
        logging.warning("PREFIX IS NOT SET, RESORTING TO WD DIRNAME.")
        prefix = Path(args.wd).name
    else:
        prefix = args.prefix

    # dataset
    if args.dataset is None:
        logging.warning("DATASET IS NOT SET, USING PREFIX INSTEAD.")
        dataset = prefix
    else:
        dataset = args.dataset

    # plotformat
    if args.plotformat not in ('svg', 'png'):
        raise ValueError(f"ERROR! UNKNOWN PLOT FORMAT! ('{args.plotformat}')")

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ INIT ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #

    # logging
    pd.options.mode.chained_assignment = None  # disable pd.io.pytables.SettingWithCopyWarning
    warnings.filterwarnings('ignore', category=pd.io.pytables.PerformanceWarning)  # if None is returned instead of a molecule, do not complain about mixed types
    lg = RDLogger.logger()
    lg.setLevel(RDLogger.INFO)
    log_file = f"{args.wd_out}/report_fcp_{prefix}.log"
    utils.check_arg_output_file(log_file)
    logger = utils._configure_logger(log_level=args.log, log_file=log_file, logger_name=log_file)

    # display rendering
    report.init_report_globals()
    pad_title = 80
    pad = 60
    # named colors of the default palette are resolved, any other value is used as is
    color = report.DEFAULT_PALETTE.get(args.color, args.color)

    # display infos
    logger.info("LIBRARY VERSIONS")
    logger.info("rdkit".ljust(pad) + f"{rdkit.__version__}")
    logger.info("pandas".ljust(pad) + f"{pd.__version__}")
    logger.info("npfc".ljust(pad) + f"{npfc.__version__}")
    logger.info("ARGUMENTS")
    logger.info("WD_in".ljust(pad) + f"{args.wd}")
    logger.info("WD_out".ljust(pad) + f"{args.wd_out}")
    logger.info("log_file".ljust(pad) + f"{log_file}")
    logger.info("color".ljust(pad) + f"{color}")
    logger.info("dataset".ljust(pad) + f"{dataset}")
    logger.info("prefix".ljust(pad) + f"{prefix}")
    logger.info("clear".ljust(pad) + f"{args.clear}")
    logger.info("regenplots".ljust(pad) + f"{args.regenplots}")
    logger.info("plotformat".ljust(pad) + f"{args.plotformat}")
    logger.info("log".ljust(pad) + f"{args.log}")
    # reference time for the summary: the initialization above is not accounted for
    d0 = datetime.now()

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BEGIN ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #

    report.print_title("REPORTING FCP", 3, pad_title)

    # parse wd
    df_fcp_nsymgroups = get_df_fcp(wd)
    d1 = datetime.now()

    # define csv outputs (the data folder is created the same way as the plot folder below)
    Path(f"{wd_out}/data").mkdir(parents=True, exist_ok=True)
    output_csv_fcp_nsymgroups = f"{wd_out}/data/{prefix}_fcp_nsymgroups.csv"
    save.file(df_fcp_nsymgroups, output_csv_fcp_nsymgroups)
    logger.info("FCP -- OUTPUT_CSV_FCP_NSYMGROUPS".ljust(pad) + f"{output_csv_fcp_nsymgroups}")
    d2 = datetime.now()

    if not args.csv:
        Path(f"{wd_out}/plot").mkdir(parents=True, exist_ok=True)

        # define plot outputs
        output_plot_fcp_nsymgroups = f"{wd_out}/plot/{prefix}_fcp_nsymgroups.{args.plotformat}"
        logger.info("FCP -- OUTPUT PLOT FILE SYNTAX: OUTPUT_CSV.PLOTFORMAT")

        # plot output_plot_fcp_nsymgroups
        logger.info("FCP -- OUTPUT_PLOT_FCP_NSYMGROUPS".ljust(pad) + "COMPUTING...")
        report.save_barplot(df_fcp_nsymgroups,
                            output_plot_fcp_nsymgroups,
                            'NumSymGroups',
                            'Count',
                            f"Number of Symmetry Groups in {dataset}",
                            x_label='NumSymGroups',
                            y_label='Count',
                            color=color,
                            perc_labels='Perc_Mols',
                            )

    d3 = datetime.now()

    # end
    logger.info("SUMMARY")
    logger.info("COMPUTATIONAL TIME: ITERATING OVER CHUNKS".ljust(pad * 2) + f"{d1-d0}")
    logger.info("COMPUTATIONAL TIME: EXPORTING CSV".ljust(pad * 2) + f"{d2-d1}")
    logger.info("COMPUTATIONAL TIME: EXPORTING PLOTS".ljust(pad * 2) + f"{d3-d2}")
    logger.info("COMPUTATIONAL TIME: TOTAL".ljust(pad * 2) + f"{d3-d0}")
    logger.info("END")


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MAIN ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


if __name__ == "__main__":
    # run the report, then terminate explicitly with a success status
    main()
    raise SystemExit(0)
