#!python

"""
Script map_frags
==========================
This script is used for generating fragment maps from fragment combination classifications.
"""

# standard
import warnings
import sys
from datetime import datetime
import pandas as pd
import argparse
# chemoinformatics
import rdkit
# dev
import npfc
from npfc import fragment_combination_graph
from npfc import load
from npfc import save
from npfc import utils
# Turn off pandas' chained-assignment check for the whole process so that
# SettingWithCopyWarning is never emitted (pandas' default for this option is 'warn').
pd.options.mode.chained_assignment = None  # default='warn'

# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FUNCTIONS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


def _parse_filter_frags(raw):
    """Convert the raw ``--filter-frags`` value into a tuple of fragment IDs.

    :param raw: comma-separated string received from the command line, or None when the option was not provided
    :return: a tuple of stripped, non-empty IDs; empty when there is nothing to filter
    """
    if raw is None:
        return ()
    # a single ID (no comma) goes through the same path; blank entries such as
    # the one produced by a trailing comma are dropped
    return tuple(frag_id for frag_id in (item.strip() for item in raw.split(',')) if frag_id)


def main():
    """Command-line entry point of the map_frags script.

    Workflow:

    1. parse and validate the command-line arguments
    2. log library versions and the effective settings
    3. load the fragment combinations from the input file
    4. optionally discard combinations involving any of the IDs listed in --filter-frags
       (matched against the 'idf1' and 'idf2' columns)
    5. generate the fragment combination graphs and save them to the output file
    6. log the time spent in each step

    The function returns None; invalid command-line usage makes argparse exit the process.
    """
    # init
    d0 = datetime.now()
    description = """Script used for generating fragment combination graphs.

    It uses the installed npfc library in your favorite env manager.

    Example:

        >>> map_frags file_fcc.csv.gz file_map.csv.gz

    """

    # parameters CLI
    parser = argparse.ArgumentParser(description=description, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('input_fcc', type=str, default=None, help="Input file for fragment combinations.")
    parser.add_argument('output_map', type=str, default=None, help="Output file basename. It gets appended with the type of output being produced: raw, clean and map.")
    parser.add_argument('-f', '--filter-frags', type=str, default=None, help="Comma-separated list of fragment IDs to filter out from the input fragment combinations.")
    parser.add_argument('--min-frags', type=int, default=2, help="Minimum number of fragments required for computing a graph. Molecules with a lower number are filtered out.")
    parser.add_argument('--max-frags', type=int, default=9999, help="Maximum number of fragments allowed for computing a graph. Molecules with a higher number are filtered out.")
    parser.add_argument('--max-overlaps', type=int, default=5, help="Maximum number of overlapping fragment combinations allowed for processing a molecule. Molecules with a higher number are filtered out.")
    parser.add_argument('--log', type=str, default='INFO', help="Specify level of logging. Possible values are: CRITICAL, ERROR, WARNING, INFO, DEBUG.")
    args = parser.parse_args()

    # logging
    logger = utils._configure_logger(args.log)
    logger.info("RUNNING FRAGMENT MAPPING")
    # if None is returned instead of a molecule, do not complain about mixed types
    # (pd.errors is the public location of the warning class)
    warnings.filterwarnings('ignore', category=pd.errors.PerformanceWarning)
    pad = 40

    # parse arguments
    utils.check_arg_input_file(args.input_fcc)
    utils.check_arg_output_file(args.output_map)
    utils.check_arg_positive_number(args.min_frags)
    utils.check_arg_positive_number(args.max_frags)
    utils.check_arg_positive_number(args.max_overlaps)
    fcc_format, fcc_compression = utils.get_file_format(args.input_fcc)
    out_format, out_compression = utils.get_file_format(args.output_map)
    filter_frags = _parse_filter_frags(args.filter_frags)

    # display infos
    logger.info("LIBRARY VERSIONS:")
    logger.info("rdkit".ljust(pad) + f"{rdkit.__version__}")
    logger.info("pandas".ljust(pad) + f"{pd.__version__}")
    logger.info("npfc".ljust(pad) + f"{npfc.__version__}")
    logger.info("ARGUMENTS:")
    logger.info("INPUT_FCC".ljust(pad) + f"{args.input_fcc}")
    logger.info("FCC_FORMAT".ljust(pad) + f"{fcc_format}")
    logger.info("FCC_COMPRESSION".ljust(pad) + f"{fcc_compression}")
    logger.info("OUTPUT_MAP".ljust(pad) + f"{args.output_map}")
    logger.info("OUT_FORMAT".ljust(pad) + f"{out_format}")
    logger.info("OUT_COMPRESSION".ljust(pad) + f"{out_compression}")
    logger.info("MIN_FRAGS".ljust(pad) + f"{args.min_frags}")
    logger.info("MAX_FRAGS".ljust(pad) + f"{args.max_frags}")
    logger.info("MAX_OVERLAPS".ljust(pad) + f"{args.max_overlaps}")
    logger.info('FILTER_FRAGS'.ljust(pad) + f"{', '.join(filter_frags)}")
    logger.info("LOG".ljust(pad) + f"{args.log}")

    # begin
    logger.info("RUNNING FRAGMENT GRAPH GENERATION")

    # load fcc
    logger.info("LOADING FRAGMENT COMBINATIONS")
    d1 = datetime.now()
    df_fcc = load.file(args.input_fcc)
    logger.info(f"FOUND {len(df_fcc.index)} FRAGMENT COMBINATIONS")
    if filter_frags:
        logger.info(f"FILTERING FRAGMENT COMBINATIONS WITH FRAG(S): {', '.join(filter_frags)}")
        # a combination is discarded as soon as either of its two fragments is blacklisted
        is_filtered = df_fcc['idf1'].isin(filter_frags) | df_fcc['idf2'].isin(filter_frags)
        df_fcc = df_fcc[~is_filtered]

    # mapping fragment combinations
    logger.info("MAPPING REMAINING FRAGMENT COMBINATIONS")
    d2 = datetime.now()
    df_map = fragment_combination_graph.generate(df_fcc, min_frags=args.min_frags, max_frags=args.max_frags, max_overlaps=args.max_overlaps)
    logger.info(f"COMPUTED {len(df_map.index)} FRAGMENT COMBINATION NETWORKS")

    # save map results
    logger.info(f"SAVING MAP RESULTS AT '{args.output_map}'")
    d3 = datetime.now()
    save.file(df_map, args.output_map, encode=True)

    # end
    d4 = datetime.now()
    logger.info("SUMMARY")
    logger.info("COMPUTATIONAL TIME: CONFIGURING JOB".ljust(pad * 2) + f"{d1-d0}")
    logger.info("COMPUTATIONAL TIME: LOADING FRAGMENT COMBINATIONS".ljust(pad * 2) + f"{d2-d1}")
    logger.info("COMPUTATIONAL TIME: MAPPING COMBINATIONS".ljust(pad * 2) + f"{d3-d2}")
    logger.info("COMPUTATIONAL TIME: SAVING FRAGMENT COMBINATIONS NETWORKS".ljust(pad * 2) + f"{d4-d3}")
    logger.info("COMPUTATIONAL TIME: TOTAL".ljust(pad * 2) + f"{d4-d0}")
    logger.info("END")


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MAIN ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


if __name__ == '__main__':
    # run the command-line workflow, then terminate explicitly with a success status
    main()
    raise SystemExit(0)
