#!python

"""
Script frags_annotate_fcp
==========================
This script is used for annotating Fragment Combination Points (FCP) in fragments.
"""

# standard
import warnings
import sys
import os
from datetime import datetime
import pandas as pd
import argparse
# chemoinformatics
import rdkit
# dev
import npfc
from npfc import load
from npfc import save
from npfc import utils
from npfc import fragment_combination_point
# Silence pandas' SettingWithCopyWarning for the whole script (chained-assignment check turned off).
pd.options.mode.chained_assignment = None  # default='warn'


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FUNCTIONS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


def main():
    """Run the FCP annotation workflow from the command line.

    Steps, in order:

    1. parse the command-line arguments, configure the logger and validate the
       input/output paths with the ``utils`` helpers;
    2. load the fragments from ``input_passed`` with ``load.file``;
    3. map ``fragment_combination_point.get_fcp_labels`` over the ``mol`` column and
       store the result in the column named by ``--fcp``;
    4. map ``fragment_combination_point.count_symmetry_groups`` over that column
       (``num_symmetry_groups``), log the distribution and, when ``--count-file``
       is given, save the distribution with ``save.file``;
    5. save the annotated fragments to ``output_fcp`` and log the time spent in
       each step.

    The function takes no parameters (everything comes from ``sys.argv`` through
    argparse) and returns ``None``.
    """

    # init
    d0 = datetime.now()
    description = """Script used for annotating Fragment Combination Points (FCP) in fragments.

    It uses the installed npfc libary in your favorite env manager.

    Example:

        >>> frags_annotate frags_passed.csv.gz frags_fcp.csv.gz

    """

    # parameters CLI
    parser = argparse.ArgumentParser(description=description, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('input_passed', type=str, default=None, help="Input file for standardized fragments.")
    parser.add_argument('output_fcp', type=str, default=None, help="Output file for annotated fragments.")
    parser.add_argument('-c', '--count-file', type=str, default=None, help="Output file for recording the counts of symmetry groups per fragment.")
    parser.add_argument('--fcp', type=str, default='_fcp_labels', help="Column name in the fragments DataFrame to use for labelling fragment connection points.")
    parser.add_argument('--log', type=str, default='INFO', help="Specify level of logging. Possible values are: CRITICAL, ERROR, WARNING, INFO, DEBUG.")
    args = parser.parse_args()

    # logging
    logger = utils._configure_logger(args.log)
    logger.info("RUNNING FRAGMENT COMBINATION CLASSIFICATION")
    warnings.filterwarnings('ignore', category=pd.io.pytables.PerformanceWarning)  # if None is returned instead of a molecule, do not complain about mixed types
    pad = 40  # width of the label column in the log output

    # parse arguments
    utils.check_arg_input_file(args.input_passed)
    utils.check_arg_output_file(args.output_fcp)
    in_format, in_compression = utils.get_file_format(args.input_passed)
    out_format, out_compression = utils.get_file_format(args.output_fcp)

    # the count file is optional, so it is validated only when requested
    has_count_file = args.count_file is not None
    if has_count_file:
        utils.check_arg_output_file(args.count_file)
        count_format, count_compression = utils.get_file_format(args.count_file)

    # display infos
    logger.info("CURRENT ENV".ljust(pad) + os.environ.get('CONDA_DEFAULT_ENV', 'UNDEFINED'))
    logger.info("LIBRARY VERSIONS:")
    logger.info("rdkit".ljust(pad) + f"{rdkit.__version__}")
    logger.info("pandas".ljust(pad) + f"{pd.__version__}")
    logger.info("npfc".ljust(pad) + f"{npfc.__version__}")
    logger.info("ARGUMENTS:")
    logger.info("INPUT".ljust(pad) + f"{args.input_passed}")
    logger.info("INPUT_FORMAT".ljust(pad) + f"{in_format}")
    logger.info("INPUT_COMPRESSION".ljust(pad) + f"{in_compression}")
    logger.info("OUTPUT".ljust(pad) + f"{args.output_fcp}")
    logger.info("OUT_FORMAT".ljust(pad) + f"{out_format}")
    logger.info("OUT_COMPRESSION".ljust(pad) + f"{out_compression}")
    logger.info("COUNT_FILE".ljust(pad) + f"{args.count_file}")
    if has_count_file:
        # report the count file's own format/compression, not the main output's
        logger.info("COUNT_FORMAT".ljust(pad) + f"{count_format}")
        logger.info("COUNT_COMPRESSION".ljust(pad) + f"{count_compression}")
    logger.info("FCP".ljust(pad) + f"{args.fcp}")
    logger.info("LOG".ljust(pad) + f"{args.log}")

    # begin
    logger.info("BEGIN")

    # load fragments
    logger.info("LOADING FRAGMENTS")
    d1 = datetime.now()
    df_frags = load.file(args.input_passed)
    logger.info(f"FOUND {len(df_frags.index)} FRAGMENTS")

    # run FCP annotation
    logger.info("ANNOTATING FRAGMENT CONNECTION POINTS")
    d2 = datetime.now()
    logger.debug(f"INPUT FRAGMENTS:\n\n{df_frags}\n")
    df_frags[args.fcp] = df_frags['mol'].map(fragment_combination_point.get_fcp_labels)

    # analyze a bit the results
    logger.info("ANALYZING RESULTS")
    d3 = datetime.now()

    # number of symmetry groups per frag
    df_frags['num_symmetry_groups'] = df_frags[args.fcp].map(fragment_combination_point.count_symmetry_groups)

    # number of fragments with symmetry groups
    logger.info(f"NUMBER OF FRAGMENTS WITH SYMMETRY GROUPS: {len(df_frags[df_frags['num_symmetry_groups'] > 0]):,}")

    # counts: one row per distinct number of symmetry groups, with the number of fragments ('idm' count)
    df_frags_counts = (
        df_frags[['idm', 'num_symmetry_groups']]
        .groupby('num_symmetry_groups')
        .count()
        .reset_index()
        .rename({'idm': 'Count', 'num_symmetry_groups': 'NumSymmetryGroups'}, axis=1)
    )
    tot = df_frags_counts['Count'].sum()
    df_frags_counts['Perc'] = df_frags_counts['Count'] / tot
    # save the actual ratio (before it gets turned into display strings below)
    if has_count_file:
        save.file(df_frags_counts, args.count_file)
    # display with nice percentages: two decimals (the former '2%' spec only set a minimum
    # width and printed the default six decimals)
    df_frags_counts['Perc'] = df_frags_counts['Perc'].map(lambda x: f"{x:.2%}")
    logger.info(f"RESULT -- NUMBER OF SYMMETRY GROUPS PER FRAGMENT:\n\n{df_frags_counts}\n")

    # save results
    logger.info(f"SAVING RESULTS AT '{args.output_fcp}'")
    d4 = datetime.now()
    save.file(df_frags, args.output_fcp)

    # end
    d5 = datetime.now()
    logger.info("SUMMARY")
    logger.info("COMPUTATIONAL TIME: CONFIGURING JOB".ljust(pad * 2) + f"{d1-d0}")
    logger.info("COMPUTATIONAL TIME: LOADING FRAGMENTS".ljust(pad * 2) + f"{d2-d1}")
    # annotation runs between d2 and d3; later minus earlier keeps the duration positive
    logger.info("COMPUTATIONAL TIME: RUNNING FCP ANNOTATION".ljust(pad * 2) + f"{d3-d2}")
    logger.info("COMPUTATIONAL TIME: ANALYZING RESULTS".ljust(pad * 2) + f"{d4-d3}")
    logger.info("COMPUTATIONAL TIME: SAVING RESULTS".ljust(pad * 2) + f"{d5-d4}")
    logger.info("COMPUTATIONAL TIME: TOTAL".ljust(pad * 2) + f"{d5-d0}")
    logger.info("END")


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MAIN ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


if __name__ == '__main__':
    # Run the workflow, then terminate with an explicit success status (exit code 0).
    main()
    raise SystemExit(0)
