#!python

"""
Script frags_search
==========================
This script is used for searching for fragments in molecules.
"""

# standard
import sys
import warnings
from datetime import datetime
import argparse
# data handling
import pandas as pd
# chemoinformatics
import rdkit
from rdkit import RDLogger
# custom libraries
import npfc
from npfc import load
from npfc import save
from npfc import utils
from npfc import fragment_search

# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FUNCTIONS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


def _str2bool(value):
    """Convert a command-line string into a real boolean.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    "False") as True, so boolean options need an explicit converter.

    :param value: the raw option value (a string, or already a bool)
    :return: True or False
    :raises argparse.ArgumentTypeError: if the value is not a recognized boolean literal
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    # the empty string is kept as False, which is what bool('') used to return
    if normalized in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value (True/False), got '{value}'")


def _parse_filter_frags(raw):
    """Split the comma-separated --filter-frags value into a tuple of stripped IDs.

    :param raw: the raw option value, or None when the option was not provided
    :return: a tuple of fragment IDs (empty when nothing has to be filtered)
    """
    if raw is None:
        return ()
    return tuple(x.strip() for x in raw.split(',') if x.strip())


def main():
    """Run the fragment search from the command line.

    Steps: parse the CLI arguments, configure logging, load the molecules and
    the fragments with ``load.file``, optionally discard fragments listed in
    --filter-frags, run ``fragment_search.get_fragment_hits`` and write the
    resulting DataFrame with ``save.file``. Timings for each step are logged
    and the process exits with status 0.

    :raises ValueError: if --fcp names a column that is missing from the fragments DataFrame
    """

    # init
    d0 = datetime.now()
    description = """Search for substructures in molecules.

    This command uses the installed npfc libary in your favorite env manager.

    """

    # parameters CLI

    parser = argparse.ArgumentParser(description=description, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('minput', type=str, default=None, help="Input file.")
    parser.add_argument('-m', '--mid', type=str, default='idm', help="Identifier column in the molecules file. If not specified, rowids are used instead (_Name for SDF).")
    parser.add_argument('--mdecode', type=_str2bool, default=True, help="Decode molecules from base64 strings.")
    parser.add_argument('--mmol', type=str, default='mol', help="Molecule column for csv or hdf files.")
    parser.add_argument('--msep', type=str, default='|', help="Separator to use in case the input file is a csv.")
    parser.add_argument('finput', type=str, default=None, help="Input frags.")
    parser.add_argument('--fid', type=str, default='idm', help="Identifier column in the fragments file. If not specified, rowids are used instead (_Name for SDF).")
    parser.add_argument('--fdecode', type=_str2bool, default=True, help="Decode fragments from base64 strings.")
    parser.add_argument('--fmol', type=str, default='mol', help="Molecule column for csv or hdf files.")
    parser.add_argument('--fsep', type=str, default='|', help="Separator to use in case the input file is a csv.")
    parser.add_argument('output', type=str, default=None, help="Output file with substructures.")
    parser.add_argument('--oencode', type=_str2bool, default=True, help="Encode molecules and fragments in base64 strings (mandatory for csv files).")
    parser.add_argument('--omol', type=str, default='mol', help="Molecule column for csv or hdf files.")
    parser.add_argument('--ofrag', type=str, default='mol_frag', help="Molecule column for csv or hdf files.")
    parser.add_argument('--osep', type=str, default='|', help="Separator to use in case the input file is a csv.")
    parser.add_argument('-f', '--filter-frags', type=str, default=None, help="List of IDs to filter out from the input fragment combinations.")
    # NOTE(review): --tautomer is forwarded as a raw string (or None); confirm what get_fragment_hits expects
    parser.add_argument('-t', '--tautomer', type=str, default=None, help="Consider tautomers for molecules during fragment search.")
    parser.add_argument('--fcp', type=str, default='_fcp_labels', help="Column name in the fragments DataFrame to use for labelling fragment connection points. If an empty string '' is provided, these will not be registered.")
    parser.add_argument('--log', type=str, default='INFO', help="Specify level of logging. Possible values are: CRITICAL, ERROR, WARNING, INFO, DEBUG.")
    args = parser.parse_args()

    # logging
    logger = utils._configure_logger(args.log)
    logger.info("RUNNING FRAGS_SEARCH")
    warnings.filterwarnings('ignore', category=pd.io.pytables.PerformanceWarning)  # if None is returned instead of a molecule, do not complain about mixed types
    pad = 40
    # silence RDKit messages below CRITICAL
    lg = RDLogger.logger()
    lg.setLevel(RDLogger.CRITICAL)

    # parse arguments

    # minput
    utils.check_arg_input_file(args.minput)
    mformat, mcompression = utils.get_file_format(args.minput)

    # finput
    utils.check_arg_input_file(args.finput)
    fformat, fcompression = utils.get_file_format(args.finput)

    # output
    utils.check_arg_output_file(args.output)
    oformat, ocompression = utils.get_file_format(args.output)
    if oformat == ".csv" and not args.oencode:
        logger.warning(f"Output format is {oformat} but oencode is set to {args.oencode}, setting it to True")
        oencode = True
    else:
        oencode = args.oencode

    # filter frags
    filter_frags = _parse_filter_frags(args.filter_frags)

    # display infos

    # versions
    logger.info("LIBRARY VERSIONS:")
    logger.info("rdkit".ljust(pad) + f"{rdkit.__version__}")
    logger.info("pandas".ljust(pad) + f"{pd.__version__}")
    logger.info("npfc".ljust(pad) + f"{npfc.__version__}")
    # arguments
    logger.info("ARGUMENTS:")
    logger.info("MINPUT".ljust(pad) + f"{args.minput}")
    logger.info("MFORMAT".ljust(pad) + f"{mformat}")
    logger.info("MCOMPRESSION".ljust(pad) + f"{mcompression}")
    logger.info("MID".ljust(pad) + f"{args.mid}")
    logger.info("MMOL".ljust(pad) + f"{args.mmol}")
    if mformat == '.csv':
        logger.info("MSEP".ljust(pad) + f"{args.msep}")
    if mformat != '.sdf':
        logger.info("MDECODE".ljust(pad) + f"{args.mdecode}")
    logger.info("FINPUT".ljust(pad) + f"{args.finput}")
    logger.info("FFORMAT".ljust(pad) + f"{fformat}")
    logger.info("FCOMPRESSION".ljust(pad) + f"{fcompression}")
    logger.info("FID".ljust(pad) + f"{args.fid}")
    logger.info("FMOL".ljust(pad) + f"{args.fmol}")
    if fformat == '.csv':
        logger.info("FSEP".ljust(pad) + f"{args.fsep}")
    if fformat != '.sdf':
        logger.info("FDECODE".ljust(pad) + f"{args.fdecode}")
    logger.info("OUTPUT".ljust(pad) + f"{args.output}")
    logger.info("OFORMAT".ljust(pad) + f"{oformat}")
    logger.info("OCOMPRESSION".ljust(pad) + f"{ocompression}")
    logger.info("OENCODE".ljust(pad) + f"{oencode}")
    logger.info("OMOL".ljust(pad) + f"{args.omol}")
    logger.info("OFRAG".ljust(pad) + f"{args.ofrag}")
    logger.info("OSEP".ljust(pad) + f"{args.osep}")
    logger.info('FILTER_FRAGS'.ljust(pad) + f"{', '.join(filter_frags)}")
    logger.info("TAUTOMER".ljust(pad) + f"{args.tautomer}")
    logger.info("FCP".ljust(pad) + f"{args.fcp}")
    logger.info("LOG".ljust(pad) + f"{args.log}")

    # begin

    logger.info("BEGIN")

    # load mols
    logger.info("LOADING MOLECULES")
    d1 = datetime.now()
    df_mols = load.file(args.minput,
                        in_id=args.mid,
                        out_id=args.mid,
                        in_mol=args.mmol,
                        out_mol=args.mmol,
                        keep_props=True,
                        csv_sep=args.msep,
                        decode=args.mdecode,
                        )
    # the molecule column is named after out_mol, so count failures in that column
    num_failed = df_mols[args.mmol].isna().sum()
    logger.info(f"LOADED {len(df_mols):,} RECORDS WITH {num_failed:,} FAILURE(S)")

    # load frags
    logger.info("LOADING FRAGMENTS")
    d2 = datetime.now()
    df_frags = load.file(args.finput,
                         in_id=args.fid,
                         out_id=args.fid,
                         in_mol=args.fmol,
                         out_mol=args.fmol,
                         csv_sep=args.fsep,
                         decode=args.fdecode,
                         )
    num_failed = df_frags[args.fmol].isna().sum()
    logger.info(f"LOADED {len(df_frags):,} RECORDS WITH {num_failed:,} FAILURE(S)")

    if args.fcp != '' and args.fcp not in df_frags.columns:
        raise ValueError(f"ERROR! FCP LABELS ARE SPECIFIED BUT COLUMN COULD NOT BE FOUND IN FRAGMENTS DATAFRAME!\n({args.fcp} NOT IN {', '.join(map(str, df_frags.columns))})")

    # eventually filter out some fragments
    if len(filter_frags) > 0:
        logger.info(f"FILTERING FRAGMENT COMBINATIONS WITH FRAG(S): {', '.join(filter_frags)}")
        # the identifier column is named after out_id (args.fid)
        df_frags = df_frags[~df_frags[args.fid].isin(filter_frags)]
        logger.info(f"REMAINING FRAGMENTS: {len(df_frags):,}")

    # fragment search (substructure search)
    logger.info("FRAGMENT SEARCH")
    d3 = datetime.now()
    df_aidxf = fragment_search.get_fragment_hits(df_mols, df_frags, tautomer=args.tautomer, fcp_labels=args.fcp)
    logger.info(f"FOUND {len(df_aidxf.index):,} FRAGMENT HITS")

    # save results
    logger.info("SAVING RESULTS")
    d4 = datetime.now()
    # NOTE(review): --omol, --ofrag and --osep are parsed and logged but not forwarded to save.file; confirm whether they should be
    outputs = save.file(df_aidxf, args.output, encode=oencode)
    logger.info(f"SAVED {outputs[0][1]:,} RECORDS AT {outputs[0][0]}")

    # end
    d5 = datetime.now()
    logger.info("SUMMARY")
    logger.info("COMPUTATIONAL TIME: CONFIGURING JOB".ljust(pad * 2) + f"{d1-d0}")
    logger.info("COMPUTATIONAL TIME: LOADING MOLECULES".ljust(pad * 2) + f"{d2-d1}")
    logger.info("COMPUTATIONAL TIME: LOADING FRAGMNENTS".ljust(pad * 2) + f"{d3-d2}")
    logger.info("COMPUTATIONAL TIME: FRAGMENT SEARCH".ljust(pad * 2) + f"{d4-d3}")
    logger.info("COMPUTATIONAL TIME: SAVING OUTPUT".ljust(pad * 2) + f"{d5-d4}")
    logger.info("COMPUTATIONAL TIME: TOTAL".ljust(pad * 2) + f"{d5-d0}")
    logger.info("END")
    sys.exit(0)


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MAIN ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


# Run the command-line entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
