#!python

"""
Script fc_classify
==========================
This script is used for classifying fragment combinations (fc) into
fragment combination categories (fcc).
"""

# standard
from pathlib import Path
import warnings
import sys
import os
from datetime import datetime
import logging
import pandas as pd
import argparse
# chemoinformatics
import rdkit
# dev
import npfc
from npfc import load
from npfc import save
from npfc import utils
from npfc import fragment_combination
# disable SettingWithCopyWarning warnings
pd.options.mode.chained_assignment = None  # default='warn'


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FUNCTIONS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


def _str2bool(value):
    """Convert a command-line value into a real boolean.

    With ``type=bool``, argparse applies ``bool()`` to the raw string, so any
    non-empty value (including ``"False"``) ends up as ``True``. This parser
    maps the usual textual spellings onto booleans instead.

    :param value: the raw value (a string from the command line, or a bool)
    :return: the corresponding boolean
    :raises argparse.ArgumentTypeError: if the value is not a recognized spelling
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    # the empty string is kept as False, as bool('') used to return
    if normalized in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value (True or False), got '{value}'")


def main():
    """Run the fragment combination classification from the command line.

    The function parses and checks the command-line arguments, loads the
    substructure hits, classifies the fragment combinations and saves the
    result. When ``--keep-raw`` is set, the classification is also saved
    before the rows flagged 'cfc' are filtered out. A summary of the time
    spent in each step is logged at the end.
    """
    # init
    t_start = datetime.now()
    description = """Script used for classifying fragment combinations.

    It uses the installed npfc libary in your favorite env manager.

    Example:

        >>> fc_classify file_sub.csv.gz file_fcc.csv.gz

    N.B. For now the sub and output files have hard-coded formats (.csv.gz or .hdf).
    """

    # parameters CLI
    parser = argparse.ArgumentParser(description=description, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('input_sub', type=str, default=None, help="Input file for substructures hits.")
    parser.add_argument('output_fcc', type=str, default=None, help="Output file basename. It gets appended with the type of output being produced: raw, clean and map.")
    parser.add_argument('-c', '--cutoff', type=int, default=3, help="Threshold value on intermediary atoms between 2 fragments. Any value combination with a higher number than this value will be rejected and flagged as false positive (cfc)")
    # booleans go through _str2bool so that an explicit "False" really disables the option
    parser.add_argument('-e', '--exclude-exocyclic', type=_str2bool, default=True, help="Exclude exocyclic atoms in fragments during the fragment combination classifcation (no more methyls turning connections into fusions).")
    parser.add_argument('--keep-raw', type=_str2bool, default=False, help="Also save raw fragment combination classification (fcc) in addition to clean fcc (without cfc false positives)")
    parser.add_argument('--log', type=str, default='INFO', help="Specify level of logging. Possible values are: CRITICAL, ERROR, WARNING, INFO, DEBUG.")
    args = parser.parse_args()

    # logging
    logger = utils._configure_logger(args.log)
    logger.info("RUNNING FRAGMENT COMBINATION CLASSIFICATION")
    warnings.filterwarnings('ignore', category=pd.io.pytables.PerformanceWarning)  # if None is returned instead of a molecule, do not complain about mixed types
    pad = 40

    # parse arguments
    utils.check_arg_input_file(args.input_sub)
    utils.check_arg_output_file(args.output_fcc)
    utils.check_arg_positive_number(args.cutoff)
    sub_format, sub_compression = utils.get_file_format(args.input_sub)
    out_format, out_compression = utils.get_file_format(args.output_fcc)
    # the raw output sits next to the clean one: "<name>_raw<suffixes>"
    output_fcc_path = Path(args.output_fcc)
    raw_name = output_fcc_path.stem.split(".")[0] + "_raw" + "".join(output_fcc_path.suffixes)
    output_raw = str(output_fcc_path.with_name(raw_name))

    # display infos
    try:
        logger.info("CURRENT ENV".ljust(pad) + f"{os.environ['CONDA_DEFAULT_ENV']}")
    except KeyError:
        logger.info("CURRENT ENV".ljust(pad) + "UNDEFINED")
    logger.info("LIBRARY VERSIONS:")
    logger.info("rdkit".ljust(pad) + f"{rdkit.__version__}")
    logger.info("pandas".ljust(pad) + f"{pd.__version__}")
    logger.info("npfc".ljust(pad) + f"{npfc.__version__}")
    logger.info("ARGUMENTS:")
    logger.info("INPUT_SUB".ljust(pad) + f"{args.input_sub}")
    logger.info("SUB_FORMAT".ljust(pad) + f"{sub_format}")
    logger.info("SUB_COMPRESSION".ljust(pad) + f"{sub_compression}")
    logger.info("OUTPUT_FCC".ljust(pad) + f"{args.output_fcc}")
    if args.keep_raw:
        logger.info("OUTPUT_RAW".ljust(pad) + f"{output_raw}")
    logger.info("OUT_FORMAT".ljust(pad) + f"{out_format}")
    logger.info("OUT_COMPRESSION".ljust(pad) + f"{out_compression}")
    logger.info("CUTOFF".ljust(pad) + f"{args.cutoff}")
    logger.info("EXCLUDE_EXOCYCLIC".ljust(pad) + f"{args.exclude_exocyclic}")
    logger.info("KEEP_RAW".ljust(pad) + f"{args.keep_raw}")
    logger.info("LOG".ljust(pad) + f"{args.log}")

    # begin
    logger.info("BEGIN")

    # load sub
    logger.info("LOADING SUBSTRUCTURE MATCHES")
    t_load_start = datetime.now()
    df_aidx = load.file(args.input_sub,
                        keep_props=True,
                        )
    logger.info(f"FOUND {len(df_aidx.index)} SUBSTRUCTURE MATCHES")

    # run fcc
    # NOTE(review): args.cutoff is checked and logged above but is not forwarded
    # to classify_df here; confirm whether it should be passed along.
    logger.info("CLASSIFYING FRAGMENT COMBINATIONS")
    t_load_end = datetime.now()
    t_classify_start = datetime.now()
    if not args.keep_raw:
        df_fcc = fragment_combination.classify_df(df_aidx, exclude_exocyclic=args.exclude_exocyclic)
        logger.info(f"FOUND {len(df_fcc.index)} COMBINATIONS")
        t_classify_end = datetime.now()
    else:
        # keep the rows flagged 'cfc' so that they can be saved before cleaning
        df_fcc = fragment_combination.classify_df(df_aidx, clear_cfc=False, exclude_exocyclic=args.exclude_exocyclic)
        logger.info(f"FOUND {len(df_fcc.index)} COMBINATIONS")

        # save raw results
        logger.info(f"SAVING RAW RESULTS AT '{output_raw}'")
        t_raw_save_start = datetime.now()
        save.file(df_fcc, output_raw)

        # run cleaning
        logger.info("CLEANING COMBINATIONS")
        t_clean_start = datetime.now()
        df_fcc = df_fcc[df_fcc['abbrev'] != 'cfc']
        logger.info(f"NUMBER OF FRAGMENT COMBINATIONS REMAINING: {len(df_fcc.index)}")
    # emitted through the configured logger, like every other message of the script
    logger.warning("OVERLAPPING FRAGMENTS ARE STILL CONSIDERED A FRAGMENT COMBINATION HERE,")
    logger.warning("THIS IS NEEDED FOR ALTERNATIVE MAPPING FRAGMENT COMBINATIONS.")

    # save clean results
    logger.info(f"SAVING CLEAN RESULTS AT '{args.output_fcc}'")
    t_save_start = datetime.now()
    save.file(df_fcc, args.output_fcc)

    # end
    t_end = datetime.now()
    logger.info("SUMMARY")
    logger.info("COMPUTATIONAL TIME: CONFIGURING JOB".ljust(pad * 2) + f"{t_load_start-t_start}")
    logger.info("COMPUTATIONAL TIME: LOADING SUBSTRUCTURE MATCHES".ljust(pad * 2) + f"{t_load_end-t_load_start}")
    if args.keep_raw:
        logger.info("COMPUTATIONAL TIME: RUNNING RAW CLASSIFICATION".ljust(pad * 2) + f"{t_raw_save_start-t_classify_start}")
        logger.info("COMPUTATIONAL TIME: SAVING RAW COMBINATIONS".ljust(pad * 2) + f"{t_clean_start-t_raw_save_start}")
        logger.info("COMPUTATIONAL TIME: CLEANING COMBINATIONS".ljust(pad * 2) + f"{t_save_start-t_clean_start}")
    else:
        logger.info("COMPUTATIONAL TIME: RUNNING CLASSIFICATION".ljust(pad * 2) + f"{t_classify_end-t_classify_start}")

    # the save duration runs from just before save.file to the end of the job
    logger.info("COMPUTATIONAL TIME: SAVING CLEAN COMBINATIONS".ljust(pad * 2) + f"{t_end-t_save_start}")
    logger.info("COMPUTATIONAL TIME: TOTAL".ljust(pad * 2) + f"{t_end-t_start}")
    logger.info("END")


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MAIN ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


if __name__ == "__main__":
    # Entry point when the file is executed as a script: run the
    # classification, then terminate explicitly with a success status.
    main()
    raise SystemExit(0)
