#!python

"""
Script annotate_pnp
==========================
This script is used for annotating Pseudo Natural Products (PNP).
"""

# standard
import warnings
import sys
from datetime import datetime
import pandas as pd
import argparse
from pathlib import Path
import logging
# chemoinformatics
import rdkit
# dev
import npfc
from npfc import load
from npfc import save
from npfc import utils
from npfc import fragment_combination_graph
# disable SettingWithCopyWarning warnings
pd.options.mode.chained_assignment = None  # default='warn'


def _str2bool(value):
    """Convert a command-line string into a boolean.

    Used as argparse ``type`` because ``type=bool`` turns every non-empty
    string (including 'False') into True.

    :param value: the raw value received from the command line
    :return: True for true/t/yes/y/1, False for false/f/no/n/0 (case-insensitive)
    :raises argparse.ArgumentTypeError: if the value is not a recognized boolean
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    if normalized in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value (true/false), got '{value}'")


def main():
    """Run the PNP annotation command-line workflow.

    Steps performed:
        1. parse and check the command-line arguments
        2. load the fragment combination graphs (fcg) to annotate
        3. load and concatenate all reference fcg files found in ref_dir (pattern '*_fcg.*')
        4. annotate the input fcgs with fragment_combination_graph.annotate_pnp
        5. save the annotated fcgs and, if requested, the list of PNP labels

    If the input file contains no fcg, empty output files are written and the
    references are not loaded at all.

    :return: None
    :raises ValueError: if ref_dir is not a directory, or if it contains no reference file while there are fcgs to annotate
    """

    # init
    d0 = datetime.now()
    description = """Script is used for annotating Pseudo Natural Products (PNP).

    It takes two inputs:
        - the fcg file with the molecules to annotate by comparing fragment fcgs (graphs)
        - a folder that contains molecular files of natural products

    It creates one output:
        - the input molecular file annotated with a new column "pnp" valuing either True or False.

    It uses the installed npfc libary in your favorite env manager.

    Example:

        >>> annotate_pnp input_fcg.csv.gz ref_dir/ output_fcg.csv.gz

    """

    # parameters CLI
    parser = argparse.ArgumentParser(description=description, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('input_fcg', type=str, default=None, help="Fragment Map file to annotate")
    parser.add_argument('ref_dir', type=str, default=None, help="Directory with Fragment Map files to use as references (currently wiht syntax: <name>_fcg.csv.gz)")
    parser.add_argument('output_fcg', type=str, default=None, help="Output file for annoated fragment graphs.")
    parser.add_argument('-d', '--data', type=str, default='fcc,fcp_1,fcp_2', help="Graph edge attributes to consider for annotating pnps (no space, separated with ',').")
    # type=bool would evaluate any non-empty string (even 'False') to True, so parse the text explicitly
    parser.add_argument('-s', '--symmetry', type=_str2bool, default=True, help="Consider symmetry when comparing FCPs. If activated, the suffixes ('1a,2') are removed ('1,2') so equivalent FCPs are considered the same.")
    parser.add_argument('-l', '--list-pnp-mols', type=str, default=None, help="Output file with the list of molecules labelled as PNP or not. It is useful to avoid loading the whole dataset just to count the proportion of PNP/non-PNP molecules.")
    parser.add_argument('--log', type=str, default='INFO', help="Specify level of logging. Possible values are: CRITICAL, ERROR, WARNING, INFO, DEBUG.")
    args = parser.parse_args()

    # logging
    logger = utils._configure_logger(args.log)
    logger.info("RUNNING PNP ANNOTATION")
    warnings.filterwarnings('ignore', category=pd.io.pytables.PerformanceWarning)  # if None is returned instead of a molecule, do not complain about mixed types
    pad = 40

    # parse arguments
    utils.check_arg_input_file(args.input_fcg)
    utils.check_arg_output_file(args.output_fcg)
    if args.list_pnp_mols is not None:
        utils.check_arg_output_file(args.list_pnp_mols)
    input_format, input_compression = utils.get_file_format(args.input_fcg)
    out_format, out_compression = utils.get_file_format(args.output_fcg)
    if args.data is None:
        data = []
    else:
        # a value without any ',' simply yields a 1-element tuple
        data = tuple(x.strip() for x in args.data.split(','))

    # get the list of reference files
    p = Path(args.ref_dir)
    if not p.is_dir():
        raise ValueError(f"ERROR! REF_DIR COULD NOT BE FOUND! ({args.ref_dir})")
    ref_files = sorted(str(x) for x in p.glob("*_fcg.*"))
    logger.info(f"FOUND {len(ref_files)} FRAGMENT MAP FILES AT {args.ref_dir}")

    # display infos
    logger.info("LIBRARY VERSIONS:")
    logger.info("rdkit".ljust(pad) + f"{rdkit.__version__}")
    logger.info("pandas".ljust(pad) + f"{pd.__version__}")
    logger.info("npfc".ljust(pad) + f"{npfc.__version__}")
    logger.info("ARGUMENTS:")
    logger.info("INPUT_FCC".ljust(pad) + f"{args.input_fcg}")
    logger.info("INPUT_FORMAT".ljust(pad) + f"{input_format}")
    logger.info("INPUT_COMPRESSION".ljust(pad) + f"{input_compression}")
    logger.info("REF_DIR".ljust(pad) + f"{args.ref_dir}")
    logger.info("NUM REF_DIR".ljust(pad) + f"{len(ref_files)}")
    logger.info("OUTPUT_MAP".ljust(pad) + f"{args.output_fcg}")
    logger.info("OUT_FORMAT".ljust(pad) + f"{out_format}")
    logger.info("OUT_COMPRESSION".ljust(pad) + f"{out_compression}")
    logger.info("LIST_PNP_MOLS".ljust(pad) + f"{args.list_pnp_mols}")
    logger.info("PNP_ATTRIBUTES".ljust(pad) + f"{', '.join(data)}")
    logger.info("SYMMETRY".ljust(pad) + f"{args.symmetry}")
    logger.info("LOG".ljust(pad) + f"{args.log}")
    d1 = datetime.now()

    # begin
    logger.info("BEGIN")

    # load fcg
    logger.info("LOADING FRAGMENT COMBINATION GRAPHS")
    df_fcg = load.file(args.input_fcg, decode=True)
    n_fcg = len(df_fcg.index)
    logger.info(f"FOUND {n_fcg:,} FRAGMENT COMBINATION GRAPHS")
    logger.debug(f"FRAGMENT COMBINATION GRAPHS TO ANNOATE:\n{df_fcg}\n")

    d2 = datetime.now()

    # avoid crash due to groupby method on empty dataframe
    if n_fcg < 1:
        logger.warning("NO FRAGMENT COMBINATION GRAPHS TO WORK WITH, ABORTING NOW. EMPTY OUTPUT FILE IS GENERATED FOR PIPELINE HANDLING.")
        df_fcg = pd.DataFrame(columns=fragment_combination_graph.DF_PNP_COLS)
        # saving results
        logger.info("SAVING RESULTS")
        logger.info(f"SAVING PNP RESULTS AT '{args.output_fcg}'")
        save.file(df_fcg, args.output_fcg, encode=True)
        if args.list_pnp_mols is not None:
            save.file(pd.DataFrame(columns=['idm', 'idfcg', 'pnp_fcg', 'pnp_mol', '_pnp_ref']), args.list_pnp_mols)
        # timestamp taken once the files are written so that the saving time is reported under the right label
        d5 = datetime.now()

        # end
        d6 = datetime.now()
        logger.info("SUMMARY")
        logger.info("COMPUTATIONAL TIME: CONFIGURING JOB".ljust(pad * 2) + f"{d1-d0}")
        logger.info("COMPUTATIONAL TIME: LOADING INPUT FRAGMENT COMBINATION GRAPHS".ljust(pad * 2) + f"{d2-d1}")
        logger.info("COMPUTATIONAL TIME: SAVING EMPTY OUTPUT".ljust(pad * 2) + f"{d5-d2}")
        logger.info("COMPUTATIONAL TIME: DISPLAYING RESULTS".ljust(pad * 2) + f"{d6-d5}")
        logger.info("COMPUTATIONAL TIME: TOTAL".ljust(pad * 2) + f"{d6-d0}")
        logger.info("END")
        return None

    # load fcg_refs
    logger.info("LOADING ALL REFERENCE FRAGMENT COMBINATION GRAPHS")
    if not ref_files:
        # pd.concat fails on an empty list with a message that does not mention the actual cause
        raise ValueError(f"ERROR! NO REFERENCE FRAGMENT COMBINATION GRAPH FILES (*_fcg.*) FOUND IN REF_DIR! ({args.ref_dir})")
    df_fcg_ref = pd.concat([load.file(x) for x in ref_files])
    logger.info(f"FOUND {len(df_fcg_ref.index):,} REFERENCE FRAGMENT COMBINATION GRAPHS")
    logger.debug(f"REFERENCE FRAGMENT COMBINATION GRAPHS:\n{df_fcg_ref}\n")

    logger.info("ANNOTATING FRAGMENT COMBINATION GRAPHS")
    d3 = datetime.now()
    df_pnp = fragment_combination_graph.annotate_pnp(df_fcg, df_fcg_ref, data=data, consider_symmetry=args.symmetry)

    # saving results
    logger.info("SAVING RESULTS")
    d4 = datetime.now()
    logger.info(f"SAVING PNP RESULTS AT '{args.output_fcg}'")
    save.file(df_pnp, args.output_fcg, encode=True)

    logger.info("EXPORTING PNP LISTS")
    d5 = datetime.now()
    df_pnp_list = df_pnp[['idm', 'idfcg', 'pnp_fcg', 'pnp_mol', '_pnp_ref']]
    if args.list_pnp_mols is not None:
        save.file(df_pnp_list, args.list_pnp_mols)
    # NOTE(review): both counts below are numbers of rows of df_pnp_list; if a molecule (idm)
    # can have several rows (idfcg), they are not numbers of distinct molecules - TODO confirm
    logger.info("NUMBER OF PNP MOLS".ljust(pad) + f"{len(df_pnp_list[df_pnp_list['pnp_mol']]):,}")
    logger.info("NUMBER OF NON-PNP MOLS".ljust(pad) + f"{len(df_pnp_list[~df_pnp_list['pnp_mol']]):,}")

    # end
    d6 = datetime.now()
    logger.info("SUMMARY")
    logger.info("COMPUTATIONAL TIME: CONFIGURING JOB".ljust(pad * 2) + f"{d1-d0}")
    logger.info("COMPUTATIONAL TIME: LOADING INPUT FRAGMENT COMBINATION GRAPHS".ljust(pad * 2) + f"{d2-d1}")
    logger.info("COMPUTATIONAL TIME: LOADING REFERENCE FRAGMENT COMBINATION GRAPHS".ljust(pad * 2) + f"{d3-d2}")
    logger.info("COMPUTATIONAL TIME: ANNOTATING FRAGMENT COMBINATION GRAPHS".ljust(pad * 2) + f"{d4-d3}")
    logger.info("COMPUTATIONAL TIME: SAVING ANNOTATED FRAGMENT COMBINATION GRAPHS".ljust(pad * 2) + f"{d5-d4}")
    logger.info("COMPUTATIONAL TIME: EXPORTING PNP LISTS".ljust(pad * 2) + f"{d6-d5}")
    logger.info("COMPUTATIONAL TIME: TOTAL".ljust(pad * 2) + f"{d6-d0}")
    logger.info("END")


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MAIN ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


if __name__ == "__main__":
    # Script entry point: run the command-line workflow, then terminate the
    # interpreter explicitly with a success status (0).
    main()
    raise SystemExit(0)
