#!python

"""
Script report_time
==========================
This script is used for counting the computational time of each step for a given
chunk.
"""

# standard
import warnings
from pathlib import Path
import logging
import re
import sys
from datetime import timedelta
from datetime import datetime
import pandas as pd
import argparse
from pandas import DataFrame
# chemoinformatics
import rdkit
# dev
import npfc
from npfc import load
from npfc import utils
import subprocess
# Turn off pandas' chained-assignment check so that SettingWithCopyWarning is
# never emitted; this is a global pandas option set as soon as the module loads.
pd.options.mode.chained_assignment = None  # default='warn'


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FUNCTIONS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


def _filter_input_file(input_file: str, prep_ref: str, natref_ref: str, frags_ref: str) -> bool:
    """Helper function to filter out files from other runs (prep_stereo, frags_crms_nobn, etc.)
    This function uses the initial matches by comparing the expected names of prep, natref and frags
    subdirectories.

    :param input_file: the input file
    :param prep_ref: the reference name for the prep subfolder
    :param natref_ref: the reference name for the prep natref
    :param prep_ref: the reference name for the prep frags
    :return: True if the input file passes all filters, False otherwise
    """
    idx_prep = 4
    idx_natref = 5
    idx_frags = 6
    num_levels_prep = 8
    num_levels_natref = 9
    num_levels_frags = 10
    natural = False
    if 'natural' in input_file:
        natural = True
        num_levels_frags = num_levels_natref

    levels = input_file.split('/')
    num_levels = len(levels)
    # print(f"num_levels={num_levels}")
    if num_levels < num_levels_prep:
        # print("not enough levels!")
        return False
    prep = levels[idx_prep]
    if prep != prep_ref:
        # print(f"prep={prep} != prep_ref ({prep_ref})")
        return False
    if not natural and num_levels >= num_levels_natref:
        natref = levels[idx_natref]
        if natref != natref_ref:
            print(f"natref={natref} != natref_ref ({natref_ref})")
            return False
    if num_levels == num_levels_frags:
        frags = levels[idx_frags]
        if frags != frags_ref:
            print(f"frags={frags} != frags_ref ({frags_ref})")
            return False

    return True





def _filter_input_file(input_file: str, prep_ref: str, natref_ref: str, frags_ref: str) -> bool:
    """Helper function to filter out files from other runs (prep_stereo, frags_crms_nobn, etc.)
    This function uses the initial matches by comparing the expected names of prep, natref and frags
    subdirectories.

    :param input_file: the input file
    :param prep_ref: the reference name for the prep subfolder
    :param natref_ref: the reference name for the prep natref
    :param prep_ref: the reference name for the prep frags
    :return: True if the input file passes all filters, False otherwise
    """
    idx_prep = 4
    idx_natref = 5
    idx_frags = 6
    levels = input_file.split('/')
    num_levels = len(levels)

    if prep_ref is None:
        # should not happen
        print("NOOOO")

    # fragments
    if prep_ref is not None and natref_ref is None and frags_ref is None:
        print("fragments")
        if levels[idx_prep] != prep_ref:
            return False

    # natural
    if natref_ref is None and frags_ref is not None:
        # happens for natural
        print("natural")
        idx_frags -= 1
        if levels[idx_prep] != prep_ref:
            return False
        if levels[idx_frags] != frags_ref:
            print(f"problem here: {levels[idx_frags]} != {frags_ref}")
            return False

    # synthetic / other
    if prep_ref is not None and natref_ref is not None and frags_ref is not None:
        # happens for synthetic
        print("synthetic")
        if levels[idx_prep] != prep_ref:
            return False
        if levels[idx_natref] != frags_ref:
            return False
        if levels[idx_frags] != frags_ref:
            return False


def _filter_input_file_fragments(input_file: str, prep_ref: str) -> bool:
    """Helper function to filter out files from other runs (prep_stereo, etc.)
    This function uses the initial matches by comparing the expected name of the prep subdirectory.

    :param input_file: the input file
    :param prep_ref: the reference name for the prep subfolder
    :return: True if the input file passes all filters, False otherwise
    """
    idx_prep = 4
    num_levels_prep = 8
    levels = input_file.split('/')
    num_levels = len(levels)
    if num_levels < num_levels_prep:
        logging.debug('fragments -- %s -- filtered out because too low levels (%s < %s)', input_file, num_levels, num_levels_prep)
        return False
    prep = levels[idx_prep]
    if prep != prep_ref:
        logging.debug('fragments -- %s -- filtered out because not matching prep subdir name (%s != %s)', input_file, prep, prep_ref)
        return False
    return True


def _filter_input_file_natural(input_file: str, prep_ref: str, frags_ref: str) -> bool:
    """Helper function to filter out files from other runs (prep_stereo, frags_crms_nobn, etc.)
    This function uses the initial matches by comparing the expected name of the prep subdirectory.

    :param input_file: the input file
    :param prep_ref: the reference name for the prep subfolder
    :param frags_ref: the reference name for the frags subfolder
    :return: True if the input file passes all filters, False otherwise
    """
    idx_prep = 4
    idx_frags = 5
    num_levels_prep = 8
    num_levels_frags = 9
    levels = input_file.split('/')
    num_levels = len(levels)
    if num_levels >= num_levels_prep:
        prep = levels[idx_prep]
        if prep != prep_ref:
            logging.debug('natural -- prep -- %s -- filtered out because not matching prep subdir name (%s != %s)', input_file, prep, prep_ref)
            return False
    if num_levels >= num_levels_frags:
        frags = levels[idx_frags]
        if frags != frags_ref:
            logging.debug('natural -- frags -- %s -- filtered out because not matching frags subdir name (%s != %s)', input_file, frags, frags_ref)
            return False
    return True


def _filter_input_file_synthetic(input_file: str, prep_ref: str, natref_ref: str, frags_ref: str) -> bool:
    """Helper function to filter out files from other runs (prep_stereo, frags_crms_nobn, etc.)
    This function uses the initial matches by comparing the expected name of the prep subdirectory.

    :param input_file: the input file
    :param prep_ref: the reference name for the prep subfolder
    :param frags_ref: the reference name for the frags subfolder
    :return: True if the input file passes all filters, False otherwise
    """
    idx_prep = 4
    idx_natref = 5
    idx_frags = 6
    num_levels_prep = 8
    num_levels_natref = 9
    num_levels_frags = 10
    levels = input_file.split('/')
    num_levels = len(levels)

    if num_levels >= num_levels_prep:
        prep = levels[idx_prep]
        if prep != prep_ref:
            logging.debug('synthetic -- prep -- %s -- filtered out because not matching prep subdir name (%s != %s)', input_file, prep, prep_ref)
            return False
    if num_levels >= num_levels_natref:
        natref = levels[idx_natref]
        if natref != natref_ref:
            logging.debug('synthetic -- prep -- %s -- filtered out because not matching natref subdir name (%s != %s)', input_file, natref, natref_ref)
            return False
    if num_levels >= num_levels_frags:
        frags = levels[idx_frags]
        if frags != frags_ref:
            logging.debug('synthetic -- prep -- %s -- filtered out because not matching frags subdir name (%s != %s)', input_file, frags, frags_ref)
            return False
    return True


def find_input_files(input_dir: str, pattern: str, exclude: list = None, prep: str = None, frags: str = None, natref: str = None) -> list:
    """Find input files corresponding to the pattern.

    The mode (fragments, commercial, natural or synthetic) is deduced from the
    numbered folder found in input_dir ('/01_fragments/', '/02_commercial/',
    '/03_natural/' or '/04_synthetic/'). Files matching the pattern are searched
    recursively, then filtered so that only regular files with more than 5 path
    components, without any excluded substring and passing the mode-specific
    subfolder filter remain.

    :param input_dir: input directory where input files and subdirectories are stored.
    :param pattern: pattern to use for globbing input files. (i.e. 'chembl_001*log')
    :param exclude: a list of substrings that can be found in the input file paths to exclude some of those. (i.e. files without the groupby column, etc). If None, defaults to ['/report/', '_fcp_symcounts.csv', '/01_chunk/', '.gz', '.hdf'].
    :param prep: the expected name of the prep subfolder
    :param frags: the expected name of the frags subfolder (natural and synthetic modes only)
    :param natref: the expected name of the natref subfolder (synthetic mode only)
    :return: a sorted list of input files
    :raises ValueError: if no known numbered folder is found in input_dir
    """
    # built here rather than in the signature to avoid a shared mutable default
    if exclude is None:
        exclude = ['/report/', '_fcp_symcounts.csv', '/01_chunk/', '.gz', '.hdf']

    known_modes = (
        ('/01_fragments/', 'fragments'),
        ('/02_commercial/', 'commercial'),
        ('/03_natural/', 'natural'),
        ('/04_synthetic/', 'synthetic'),
    )
    mode = next((name for marker, name in known_modes if marker in input_dir), None)
    if mode is None:
        raise ValueError('ERROR! UNKNOWN MODE BECAUSE NOT IN "01_fragments", "02_commercial", "03_natural", "04_synthetic"!')

    candidates = (str(x) for x in Path(input_dir).rglob(pattern))
    input_files = [
        x for x in candidates
        if not any(e in x for e in exclude)
        and Path(x).is_file()
        and len(x.split('/')) > 5
    ]

    # commercial data is filtered like fragments (prep subfolder only)
    if mode in ('fragments', 'commercial'):
        input_files = [x for x in input_files if _filter_input_file_fragments(x, prep)]
    elif mode == 'natural':
        input_files = [x for x in input_files if _filter_input_file_natural(x, prep, frags)]
    elif mode == 'synthetic':
        input_files = [x for x in input_files if _filter_input_file_synthetic(x, prep, natref, frags)]
    return sorted(input_files)


def get_time(input_log: str) -> float:
    """Parse an input log file and retrieve the computational time.

    The first line containing "COMPUTATIONAL TIME: TOTAL" is used; the duration
    is expected right after this marker, as printed by str(timedelta):
    "[D day[s], ]H:MM:SS[.ffffff]". The fractional part is optional (it is
    omitted by timedelta when the microseconds are 0) and so is the day count.

    :param input_log: the log file to parse
    :return: the total computational time in seconds
    :raises ValueError: if the duration cannot be found or parsed in the log file
    """
    marker = "COMPUTATIONAL TIME: TOTAL"
    # '@' is used as separator so that each log line ends up in a single column
    df = pd.read_csv(input_log, sep="@", header=None)
    try:
        line = df[df[0].str.contains(marker, regex=False, na=False)].iloc[0][0]
        # locate the duration relative to the marker instead of a fixed token position
        fields = line.split(marker, 1)[1].split()
        days = int(fields[0]) if len(fields) > 1 else 0
        hours, minutes, seconds = fields[-1].split(':')
        whole_seconds, _, fraction = seconds.partition('.')
        delta = timedelta(days=days, hours=int(hours), minutes=int(minutes), seconds=int(whole_seconds))
        microseconds = float(f"0.{fraction}") if fraction else 0.0
    except (IndexError, ValueError) as err:
        logging.error("ERROR! COULD NOT PARSE FILE '%s'", input_log)
        raise ValueError('ABORTING EXECUTION...') from err

    return delta.total_seconds() + microseconds



def main():
    """Command-line entry point of the script.

    Parses the command-line arguments, logs the library versions and the
    arguments, collects the computational time of every matching log file and
    writes them as one single row into the output csv file (the row is appended
    without header if the file already exists).
    """
    # init
    started = datetime.now()
    description = """Script is used for gathering computational time in each step of the pipeline
    for a given chunk using all corresponding log files.

    I takes two inputs:

        - input directory
        - pattern for matching the corresponding chunk

    It creates one output:
        - a csv file with all computational times for each step as a single row.

    All chunk computational times can then easilsy gathered manually using pandas.

    It uses the installed npfc libary in your favorite env manager.

    Example:

        >>> report_time fc/04_synthetic/chembl/data/prep 'chembl_001*' fc/04_synthetic/chembl/data/prep/report/chembl_001_time.csv

    """

    # parameters CLI
    parser = argparse.ArgumentParser(
        description=description,
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument('input_dir', type=str, default=None,
                        help="Input directory (i.e. fc/04_synthetic/chembl/data/prep)")
    parser.add_argument('pattern', type=str, default=None,
                        help="Pattern of the file (i.e. prefix of a chunk) to investigate. All matching files in subdirectories will be processed.")
    parser.add_argument('output_csv', type=str, default=None,
                        help="Output file for recording the computational times.")
    parser.add_argument('-p', '--prep', type=str, default=None,
                        help="Subfolder name for prep step.")
    parser.add_argument('-n', '--natref', type=str, default=None,
                        help="Subfolder name for natref step.")
    parser.add_argument('-f', '--frags', type=str, default=None,
                        help="Subfolder name for frags step.")
    parser.add_argument('--log', type=str, default='INFO',
                        help="Specify level of logging. Possible values are: CRITICAL, ERROR, WARNING, INFO, DEBUG.")
    args = parser.parse_args()

    # logging
    log = utils._configure_logger(args.log)
    log.info("RUNNING REPORT TIME")
    # if None is returned instead of a molecule, do not complain about mixed types
    warnings.filterwarnings('ignore', category=pd.io.pytables.PerformanceWarning)
    width = 40

    # check arguments; a single trailing slash is dropped from the input directory
    utils.check_arg_input_dir(args.input_dir)
    if args.input_dir.endswith('/'):
        args.input_dir = args.input_dir[:-1]
    utils.check_arg_output_file(args.output_csv)
    out_format, out_compression = utils.get_file_format(args.output_csv)

    # display infos
    log.info("LIBRARY VERSIONS:")
    for name, module in (("rdkit", rdkit), ("pandas", pd), ("npfc", npfc)):
        log.info(name.ljust(width) + f"{module.__version__}")
    log.info("ARGUMENTS:")
    shown_arguments = (
        ("INPUT_DIR", args.input_dir),
        ("PATTERN", args.pattern),
        ("OUTPUT_FILE", args.output_csv),
        ("OUT_FORMAT", out_format),
        ("OUT_COMPRESSION", out_compression),
        ("LOG", args.log),
    )
    for label, value in shown_arguments:
        log.info(label.ljust(width) + f"{value}")
    configured = datetime.now()

    # begin
    log.info("BEGIN")

    # one column per step, each holding the computational time of that step
    log_files = find_input_files(args.input_dir, args.pattern, prep=args.prep, natref=args.natref, frags=args.frags)
    row = {}
    log.info("".join(title.ljust(15) for title in ("STEP", "TIME", "INPUT_FILE")))
    for log_file in log_files:
        # the step is named after the directory two levels above the log file
        step = Path(log_file).parent.parent.name
        elapsed = get_time(log_file)
        row.update({"pattern": [args.pattern], f"{step}_time": [elapsed]})
        log.info("".join(f"{value}".ljust(15) for value in (step, elapsed, log_file)))

    # assemble data into a DF of one single row
    table = DataFrame(row)
    parsed = datetime.now()

    # export results: append without header if the output file already exists
    appending = Path(args.output_csv).exists()
    table.to_csv(args.output_csv, mode='a' if appending else 'w+', sep='|', header=not appending, index=False)

    # end
    finished = datetime.now()
    log.info("SUMMARY")
    log.info("COMPUTATIONAL TIME: CONFIGURING JOB".ljust(width * 2) + f"{configured-started}")
    log.info("COMPUTATIONAL TIME: PARSING INPUT FILES".ljust(width * 2) + f"{parsed-configured}")
    log.info("COMPUTATIONAL TIME: EXPORTING RESULTS".ljust(width * 2) + f"{finished-parsed}")
    log.info("COMPUTATIONAL TIME: TOTAL".ljust(width * 2) + f"{finished-started}")
    log.info("END")


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MAIN ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


if __name__ == '__main__':
    # run the command-line interface, then end the process with a success status
    main()
    raise SystemExit(0)
