#!/usr/bin/env python

"""
Script report_prep
==================
This script is used for reporting the prep steps.
"""

# standard
import warnings
import logging
import argparse
import sys
from shutil import rmtree
from datetime import datetime
import re
from pathlib import Path
from collections import OrderedDict
# data handling
import numpy as np
import json
import pandas as pd
from pandas import DataFrame
import networkx as nx
# data visualization
from matplotlib import pyplot as plt
import seaborn as sns
import matplotlib.ticker as mticker
# from pylab import savefig
from adjustText import adjust_text
# chemoinformatics
import rdkit
from rdkit import Chem
from rdkit.Chem import Mol
# docs
from typing import List
from typing import Tuple
from rdkit import RDLogger
# custom libraries
import npfc
from npfc import utils
from npfc import load
from npfc import save
from npfc import draw
from npfc import report
from multiprocessing import Pool
# disable the SettingWithCopyWarning warnings raised by pandas on chained assignments
pd.options.mode.chained_assignment = None  # default='warn'


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FUNCTIONS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


def _parse_std_chunk(c):
    """Count the entries of one std output chunk for every task.

    :param c: the chunk file, handed over to load.file with decode=False
    :return: a DF indexed by task with a single column 'Count' (number of non-null 'status' values)
    """
    df_chunk = load.file(c, decode=False)
    counts_by_task = df_chunk.groupby("task").count()
    # one column is enough to carry the counts, 'status' is used for that
    return counts_by_task[['status']].rename(columns={'status': 'Count'})


def _parse_std_chunks(chunks: List[str]) -> DataFrame:
    """Parse all output files of a category of std results and return the corresponding results summary.

    Every chunk is parsed in a separate worker process, then the counts found
    for each task are summed up over all chunks.

    :param chunks: output files for a category of std results
    :return: summary DF with columns 'Count' and 'Category' (the task name without any 'filter_' substring), or an empty DF with the same columns if no case was found
    """
    # nothing to parse: do not spawn worker processes for nothing
    if len(chunks) == 0:
        return pd.DataFrame([], columns=["Count", "Category"])

    # parse all files in parallel, the context manager releases the workers
    # even if parsing one of the chunks raises
    with Pool() as pool:
        dfs = pool.map(_parse_std_chunk, chunks)

    # if no case was found, return an empty dataframe
    dfs = [df for df in dfs if len(df) > 0]
    if len(dfs) == 0:
        return pd.DataFrame([], columns=["Count", "Category"])

    # concatenate all dataframes and sum the counts of the rows sharing the same index label (the task)
    df = pd.concat(dfs).groupby(level=0).sum()
    df["Category"] = df.index.map(lambda x: x.replace('filter_', ''))

    return df.reset_index(drop=True)


def parse_chunk_load(c: str):
    """Extract the numbers of loaded and failed molecules from one log file of the load step.

    The first line containing 'FAILURE' provides the number of errors (10th
    whitespace-separated token) and the first line containing 'SAVED' provides
    the number of passed molecules (7th whitespace-separated token).

    :param c: the log file of a chunk
    :return: a tuple (passed, errors)
    """
    # '@' is used as separator so that every line is read into the first column
    log_lines = pd.read_csv(c, sep="@", header=None)[0]

    failure_line = log_lines[log_lines.str.contains("FAILURE")].iloc[0]
    num_errors = int(failure_line.split()[9])

    saved_line = log_lines[log_lines.str.contains("SAVED")].iloc[0]
    num_passed = int(saved_line.split()[6])

    return (num_passed, num_errors)


def get_df_load(WD: str) -> DataFrame:
    """Get a DF summarizing the load step.

    The log file of every chunk is parsed in a separate worker process, then
    the numbers of loaded and failed molecules are summed up over all chunks.

    :param WD: the directory of the load step (log files are searched in WD/log)
    :return: a DF with the columns 'Category' ('loaded' and 'cannot_load') and 'Count'
    """
    logger.info("PREP -- COMPUTING LOAD RESULTS")
    # iterate over the log files to count status
    pattern = ".*([0-9]{3})?.log"
    chunks = report._get_chunks(f"{WD}/log", pattern)
    logger.info(f"FOUND {len(chunks):,d} CHUNKS")

    # the context manager releases the workers even if parsing a log file raises
    with Pool() as pool:
        results = pool.map(parse_chunk_load, chunks)

    # every result is a tuple (passed, errors), sum them position-wise
    num_passed = sum(passed for passed, _ in results)
    num_errors = sum(errors for _, errors in results)

    # create a dataframe with counts
    df = pd.DataFrame({'Category': ['loaded', 'cannot_load'], 'Count': [num_passed, num_errors]})
    logger.info(f"PREP -- RESULTS FOR LOADING MOLECULES:\n\n{df}\n")

    return df


def parse_chunk_std_passed(c: str):
    """Count the entries of one chunk of passed molecules of the std step.

    :param c: the chunk file, handed over to load.file with decode=False
    :return: the length of the index of the loaded data
    """
    df_chunk = load.file(c, decode=False)
    return len(df_chunk.index)


def get_df_std_passed(WD: str) -> DataFrame:
    """Get a DF summarizing the passed results of the standardization step.

    Every chunk is parsed in a separate worker process, then the numbers of
    entries are summed up over all chunks.

    :param WD: the directory of the std step (data files are searched in WD/data)
    :return: a DF with the columns 'Category' ('passed') and 'Count'
    """
    logger.info("PREP -- COMPUTING STD PASSED RESULTS")
    # collect the data files of the std step
    pattern = ".*([0-9]{3})?_std.csv.gz"
    chunks = report._get_chunks(f"{WD}/data", pattern)
    logger.info(f"PREP -- FOUND {len(chunks):,d} CHUNKS")

    # the context manager releases the workers even if parsing a chunk raises
    with Pool() as pool:
        results = pool.map(parse_chunk_std_passed, chunks)

    # gather data
    num_passed = sum(results)
    df = pd.DataFrame({"Category": ['passed'], 'Count': [num_passed]})
    logger.info(f"PREP -- RESULTS FOR STD PASSED:\n\n{df}\n")

    return df


def get_df_std_filtered(WD: str) -> DataFrame:
    """Get a DF summarizing the filtered results of the standardization step.

    :param WD: the directory of the std step (files are searched in WD/log)
    :return: a DF summarizing filtered results of the standardization step
    """
    logger.info("PREP -- COMPUTING STD FILTERED RESULTS")

    # collect the files listing the filtered molecules
    log_dir = f"{WD}/log"
    filtered_chunks = report._get_chunks(log_dir, ".*([0-9]{3})?_filtered.csv.gz")
    logger.info(f"PREP -- FOUND {len(filtered_chunks):,d} CHUNKS")

    # sum up the counts of all chunks
    df_filtered = _parse_std_chunks(filtered_chunks)
    logger.info(f"RESULTS FOR STD FILTERED:\n\n{df_filtered}\n")

    return df_filtered


def get_df_std_error(WD: str) -> DataFrame:
    """Get a DF summarizing the error results of the standardization step.

    :param WD: the directory of the std step (files are searched in WD/log)
    :return: a DF summarizing error results of the standardization step
    """
    logger.info("PREP -- COMPUTING STD ERROR RESULTS")

    # collect the files listing the molecules that raised an error
    log_dir = f"{WD}/log"
    error_chunks = report._get_chunks(log_dir, ".*([0-9]{3})?_error.csv.gz")
    logger.info(f"PREP -- FOUND {len(error_chunks):,d} CHUNKS")

    # sum up the counts of all chunks
    df_error = _parse_std_chunks(error_chunks)
    logger.info(f"PREP -- RESULTS FOR STD ERROR:\n\n{df_error}\n")

    return df_error


def parse_chunk_dedupl(c):
    """Extract the numbers of kept and removed molecules from one log file of the dedupl step.

    The first line containing 'REMAINING MOLECULES' is expected to end with
    'MOLECULES: <passed>/<total>'.

    :param c: the log file of a chunk
    :return: a tuple (passed, filtered) where filtered is total - passed
    """
    # '@' is a char not found in the log file, so all lines are extracted as one column
    log_lines = pd.read_csv(c, sep="@", header=None)[0]
    summary_line = log_lines[log_lines.str.contains("REMAINING MOLECULES")].iloc[0]

    # the counts are written after the label as 'passed/total'
    counts = summary_line.split("MOLECULES:")[1]
    num_passed, num_total = tuple(int(token) for token in counts.split("/"))

    return (num_passed, num_total - num_passed)


def get_df_dedupl(WD: str) -> DataFrame:
    """Get a DF summarizing the results of the deduplication step.

    The log file of every chunk is parsed in a separate worker process, then
    the numbers of kept and removed molecules are summed up over all chunks.

    :param WD: the directory of the dedupl step (log files are searched in WD/log)
    :return: a DF with the columns 'Category' ('unique' and 'duplicate') and 'Count'
    """
    logger.info("PREP -- COMPUTING DEDUPL RESULTS")
    # iterate over the log files to count status
    pattern = ".*([0-9]{3})?_dedupl.log"
    chunks = report._get_chunks(f"{WD}/log", pattern)
    # quick and dirty: only keep the files whose last extension is 'log'
    # (presumably the pattern alone can match other files as well)
    chunks = [c for c in chunks if c.split('.')[-1] == 'log']
    logger.info(f"PREP -- FOUND {len(chunks):,d} CHUNKS")

    # the context manager releases the workers even if parsing a log file raises
    with Pool() as pool:
        results = pool.map(parse_chunk_dedupl, chunks)

    # every result is a tuple (passed, filtered), sum them position-wise
    num_passed = sum(passed for passed, _ in results)
    num_filtered = sum(filtered for _, filtered in results)

    # create a dataframe with counts
    df = pd.DataFrame({'Category': ['unique', 'duplicate'], 'Count': [num_passed, num_filtered]})
    logger.info(f"PREP -- RESULTS FOR DEDUPL:\n\n{df}\n")

    return df


def _complete_categories(df: DataFrame, categories: List[str], num_mols_tot: int) -> DataFrame:
    """Return df with exactly one row per expected category, in the given order.

    Categories without any occurrence get a Count of 0 and categories that are
    not listed are dropped. A 'Perc_Status' column with the count relative to
    num_mols_tot (formatted as a percentage) is added.

    :param df: a DF with at least the columns 'Category' and 'Count'
    :param categories: the expected categories
    :param num_mols_tot: the total number of molecules used for computing percentages
    :return: the completed DF
    """
    df = df.set_index('Category').reindex(categories).reset_index()
    # categories introduced by the reindex have no count yet
    df = df.fillna(0)
    df['Count'] = df['Count'].astype(int)
    df['Perc_Status'] = df['Count'].map(lambda x: f"{x/num_mols_tot:.2%}")
    return df


def get_dfs_prep(WD: str) -> dict:
    """Get a dict of DFs summarizing the whole preprocess superstep: load, std and dedupl.

    - df_prep_filtered is the detailed summary of molecules filtered during std and dedupl
    - df_prep_error is the detailed summary of molecules that raised an error during std and load
    - df_prep_overview is the general summary with the final number of passed, filtered and error molecules.

    :param WD: the directory containing the '*_load', '*_std' and '*_dedupl' subfolders
    :return: a dict of DFs of interest with the keys 'df_prep_overview', 'df_prep_filtered' and 'df_prep_error'
    :raises IndexError: if one of the expected subfolders cannot be found
    """

    logger.info("PREP -- COMPUTE RESULTS FOR PREPROCESS")
    logger.info("PREP -- PROPRESS CONTAINS LOAD, DEGLYCO, STD AND DEDUPL STEPS")

    # define subfolders, the first match of each pattern is used
    p = Path(WD)
    WD_LOAD = str(list(p.glob("*_load"))[0])
    WD_STD = str(list(p.glob("*_std"))[0])
    WD_DEDUPL = str(list(p.glob("*_dedupl"))[0])

    # get dfs
    df_load = get_df_load(WD_LOAD)
    df_std_passed = get_df_std_passed(WD_STD)
    df_std_filtered = get_df_std_filtered(WD_STD)
    df_std_error = get_df_std_error(WD_STD)
    df_dedupl = get_df_dedupl(WD_DEDUPL)

    # get total of molecules in input
    num_mols_tot = df_load['Count'].sum()
    logger.info(f"PREP -- TOTAL NUMBER OF MOLECULES: {num_mols_tot:,}")

    # gather all filtered molecules: duplicates are reported together with the std filters
    df_dedupl_dupl = df_dedupl[df_dedupl['Category'] == 'duplicate']
    num_dedupl_dupl = df_dedupl_dupl['Count'].sum()
    df_std_filtered = pd.concat([df_std_filtered, df_dedupl_dupl], sort=True)
    # count even unoccurred cases in df_std_filtered
    filters = ['empty', 'num_heavy_atoms', 'molecular_weight', 'num_rings', 'elements', 'timeout', 'unwanted', 'duplicate']
    df_std_filtered = _complete_categories(df_std_filtered, filters, num_mols_tot)
    logger.info(f"PREP -- RESULTS FOR STD_FILTERED:\n\n{df_std_filtered}\n")

    # gather all molecules that raised an error: molecules that could not be loaded are reported together with the std errors
    df_std_error = pd.concat([df_std_error, df_load[df_load['Category'] == 'cannot_load']], sort=True)
    # count even unoccurred cases in df_std_error
    errors = ['cannot_load', 'initiate_mol', 'disconnect_metal', 'sanitize', 'clear_isotopes', 'normalize', 'uncharge', 'canonicalize', 'clear_stereo', 'empty_final']
    df_std_error = _complete_categories(df_std_error, errors, num_mols_tot)
    logger.info(f"PREP -- RESULTS FOR STD_ERRORS:\n\n{df_std_error}\n")

    # general count for passed/filtered/errors
    num_tot_filtered = df_std_filtered['Count'].sum()
    num_tot_passed = df_std_passed['Count'].sum() - num_dedupl_dupl  # dedupl happens after std, so std contains passed mols that get filtered
    num_tot_errors = df_std_error['Count'].sum()

    df_prep_overview = pd.DataFrame({'Category': ['passed', 'filtered', 'errors'], 'Count': [num_tot_passed, num_tot_filtered, num_tot_errors]})
    df_prep_overview['Perc_Status'] = df_prep_overview['Count'].map(lambda x: f"{x/num_mols_tot:.2%}")
    logger.info(f"PREP -- RESULTS FOR STD_ALL:\n\n{df_prep_overview}\n")

    return {'df_prep_overview': df_prep_overview, 'df_prep_filtered': df_std_filtered, 'df_prep_error': df_std_error}


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BEGIN ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


def _str2bool(value: str) -> bool:
    """Convert the textual value of a boolean CLI option into a real bool.

    Without this conversion any non-empty string, including 'False', is truthy.

    :param value: the raw value found on the command line
    :return: the corresponding boolean
    :raises argparse.ArgumentTypeError: if the value is not a recognized boolean
    """
    normalized = value.strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    if normalized in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got '{value}'")


class _PrepHelpFormatter(argparse.RawDescriptionHelpFormatter, argparse.ArgumentDefaultsHelpFormatter):
    """Help formatter that keeps the layout of the description and displays default values."""


def main():
    """Parse the command line, summarize the preparation steps and export the report.

    The summary DFs computed by get_dfs_prep are saved as CSV files in
    wd_out/data and, unless --csv is set, as barplots in wd_out/plot.
    Everything is logged into wd_out/report_prep_<prefix>.log.

    :raises ValueError: if the plot format is neither 'svg' nor 'png'
    """

    # init
    description = """Script for reporting the preparation step.

    Preparation steps include:
        - load
        - standardize
        - deduplicate

    This script takes two inputs:
        - the root folder where all preparation substeps are located, i.e.:

        fc/
        ├── 03_natural/
        │   ├── coconut
                ├── data
                    ├── prep  <- root folder for preparation step
                        ├── 01_chunk
                        ├── 02_load
                        ├── 03_standardize
                        ├── 04_deduplicate
                        ├── 05_depict
                        ├── frags_crms
                            ├── 06_fs
                            ├── 07_fcc
                            ├── 08_fcg

        - the output folder where to save the report (log, csv and plot files)


    Some parameters can be used for adjusting the reporting, type report_prep -h
    to display them.

    This script uses the installed npfc libary in your favorite env manager.

    Example:

        >>> report_prep fc/04_synthetic/chembl/data/prep fc/04_synthetic/chembl/data/prep/report -d "My Data Set" -p dataset -c blue

    """

    # parameters CLI
    parser = argparse.ArgumentParser(description=description, formatter_class=_PrepHelpFormatter)
    parser.add_argument('wd', type=str, default=None, help="Working directory where the data to parse is")
    parser.add_argument('wd_out', type=str, default=None, help="Output directory")
    parser.add_argument('-d', '--dataset', type=str, default=None, help="Dataset name for using in the csv/png outputs in the report folder.")
    parser.add_argument('-p', '--prefix', type=str, default=None, help="Prefix used for output files in the data/log folders.")
    parser.add_argument('-c', '--color', type=str, default='black', help="Color to use for plots.")
    parser.add_argument('--plotformat', type=str, default='svg', help="Format to use for plots. Possible values are 'svg' and 'png'.")
    # boolean options still take a value (i.e. --csv True), but it is now converted into a real bool
    parser.add_argument('--csv', type=_str2bool, default=False, help="Generate only CSV output files")
    # NOTE: --clear and --regenplots are parsed and logged only, nothing in this function acts upon them
    parser.add_argument('--clear', type=_str2bool, default=False, help="Force the generation of log, plot and CSV files by clearing all report files at any found specified levels.")
    parser.add_argument('--regenplots', type=_str2bool, default=False, help="Force the geeration of plots by clearing any pre-existing plot at any specified levels.")
    parser.add_argument('--log', type=str, default='INFO', help="Specify level of logging. Possible values are: CRITICAL, ERROR, WARNING, INFO, DEBUG.")
    args = parser.parse_args()

    # check arguments

    # I/O
    utils.check_arg_input_dir(args.wd)
    utils.check_arg_output_dir(args.wd_out)
    wd = Path(args.wd)
    wd_out = Path(args.wd_out)

    # prefix (the script logger is not configured yet, so the root logger is used for warnings)
    if args.prefix is None:
        logging.warning("PREFIX IS NOT SET, RESORTING TO WD DIRNAME.")
        prefix = Path(args.wd).name
    else:
        prefix = args.prefix

    # dataset
    if args.dataset is None:
        logging.warning("DATASET IS NOT SET, USING PREFIX INSTEAD.")
        dataset = prefix
    else:
        dataset = args.dataset

    # plotformat
    if args.plotformat not in ('svg', 'png'):
        raise ValueError(f"ERROR! UNKNOWN PLOT FORMAT! ('{args.plotformat}')")

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ INIT ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #

    # logging: the logger is shared as a module-level global with the other functions of this script
    global logger

    pd.options.mode.chained_assignment = None  # disable pd.io.pytables.SettingWithCopyWarning
    warnings.filterwarnings('ignore', category=pd.io.pytables.PerformanceWarning)  # if None is returned instead of a molecule, do not complain about mixed types
    lg = RDLogger.logger()
    lg.setLevel(RDLogger.INFO)
    log_file = f"{args.wd_out}/report_prep_{prefix}.log"
    utils.check_arg_output_file(log_file)
    logger = utils._configure_logger(log_level=args.log, log_file=log_file, logger_name=log_file)

    # display rendering
    report.init_report_globals()
    pad_title = 80
    pad = 60
    # named colors of the default palette are resolved, any other value is used as is
    color = report.DEFAULT_PALETTE.get(args.color, args.color)

    # display infos
    logger.info("LIBRARY VERSIONS")
    logger.info("rdkit".ljust(pad) + f"{rdkit.__version__}")
    logger.info("pandas".ljust(pad) + f"{pd.__version__}")
    logger.info("npfc".ljust(pad) + f"{npfc.__version__}")
    logger.info("ARGUMENTS")
    logger.info("WD_in".ljust(pad) + f"{args.wd}")
    logger.info("WD_out".ljust(pad) + f"{args.wd_out}")
    logger.info("log_file".ljust(pad) + f"{log_file}")
    logger.info("color".ljust(pad) + f"{color}")
    logger.info("dataset".ljust(pad) + f"{dataset}")
    logger.info("prefix".ljust(pad) + f"{prefix}")
    logger.info("clear".ljust(pad) + f"{args.clear}")
    logger.info("regenplots".ljust(pad) + f"{args.regenplots}")
    logger.info("plotformat".ljust(pad) + f"{args.plotformat}")
    logger.info("log".ljust(pad) + f"{args.log}")
    # start of the timings reported in the summary
    d0 = datetime.now()

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BEGIN ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #

    report.print_title("REPORTING PREP", 3, pad_title)

    # parse relevant prep subdirectories
    dfs_prep = get_dfs_prep(wd)
    df_prep_overview = dfs_prep['df_prep_overview']
    df_prep_filtered = dfs_prep['df_prep_filtered']
    df_prep_error = dfs_prep['df_prep_error']
    d1 = datetime.now()

    # define csv outputs
    output_csv_prep_overview = f"{wd_out}/data/{prefix}_prep_overview.csv"
    output_csv_prep_filtered = f"{wd_out}/data/{prefix}_prep_filtered.csv"
    output_csv_prep_error = f"{wd_out}/data/{prefix}_prep_error.csv"
    save.file(df_prep_overview, output_csv_prep_overview)
    save.file(df_prep_filtered, output_csv_prep_filtered)
    save.file(df_prep_error, output_csv_prep_error)
    logger.info("PREP -- OUTPUT_CSV_PREP_OVERVIEW".ljust(pad) + f"{output_csv_prep_overview}")
    logger.info("PREP -- OUTPUT_CSV_PREP_FILTERED".ljust(pad) + f"{output_csv_prep_filtered}")
    logger.info("PREP -- OUTPUT_CSV_PREP_ERROR".ljust(pad) + f"{output_csv_prep_error}")
    d2 = datetime.now()

    # define plot outputs
    output_plot_prep_overview = f"{wd_out}/plot/{prefix}_prep_overview.{args.plotformat}"
    output_plot_prep_filtered = f"{wd_out}/plot/{prefix}_prep_filtered.{args.plotformat}"
    output_plot_prep_error = f"{wd_out}/plot/{prefix}_prep_error.{args.plotformat}"

    logger.info("PREP -- OUTPUT PLOT FILE SYNTAX: OUTPUT_CSV.PLOTFORMAT")

    if not args.csv:
        Path(f"{wd_out}/plot").mkdir(exist_ok=True)

        # plot output_plot_prep_overview
        logger.info("PREP -- OUTPUT_PLOT_PREP_OVERVIEW".ljust(pad) + "COMPUTING...")
        report.save_barplot(df_prep_overview,
                            output_plot_prep_overview,
                            'Category',
                            'Count',
                            f"Preparation of Molecules in {dataset} - Overview",
                            x_label='Category',
                            y_label='Count',
                            color=color,
                            perc_labels='Perc_Status',
                            fig_size=(12, 12),
                            )
        # plot output_plot_prep_filtered
        logger.info("PREP -- OUTPUT_PLOT_PREP_FILTERED".ljust(pad) + "COMPUTING...")
        report.save_barplot(df_prep_filtered,
                            output_plot_prep_filtered,
                            'Category',
                            'Count',
                            f"Preparation of Molecules in {dataset} - Filtered",
                            x_label='Category',
                            y_label='Count',
                            color=color,
                            perc_labels='Perc_Status',
                            )

        # plot output_plot_prep_error
        logger.info("PREP -- OUTPUT_PLOT_PREP_ERROR".ljust(pad) + "COMPUTING...")
        report.save_barplot(df_prep_error,
                            output_plot_prep_error,
                            'Category',
                            'Count',
                            f"Preparation of Molecules in {dataset} - Error",
                            x_label='Category',
                            y_label='Count',
                            color=color,
                            perc_labels='Perc_Status',
                            )
    d3 = datetime.now()

    # end
    logger.info("SUMMARY")
    logger.info("COMPUTATIONAL TIME: ITERATING OVER CHUNKS".ljust(pad * 2) + f"{d1-d0}")
    logger.info("COMPUTATIONAL TIME: EXPORTING CSV".ljust(pad * 2) + f"{d2-d1}")
    logger.info("COMPUTATIONAL TIME: EXPORTING PLOTS".ljust(pad * 2) + f"{d3-d2}")
    logger.info("COMPUTATIONAL TIME: TOTAL".ljust(pad * 2) + f"{d3-d0}")
    logger.info("END")


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MAIN ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


if __name__ == '__main__':
    # run the report, then terminate explicitly with a success exit code
    main()
    raise SystemExit(0)
