#!python

"""
Script standardize_mols
==========================
This script is used for standardizing molecules.
"""

# standard
import sys
import warnings
from pathlib import Path
from datetime import datetime
import logging
import argparse
import pkg_resources
# data
import pandas as pd
# chemoinformatics
import rdkit
from rdkit import RDLogger
# dev
import npfc
from npfc import load
from npfc import save
from npfc import utils
from npfc import standardize


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FUNCTIONS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


def disable_rdkit_logging():
    """
    Try to silence RDKit log output.

    Raises the level of the RDKit logger to ERROR and switches off the
    'rdApp.error' log channel.

    See https://github.com/rdkit/rdkit/issues/2320
    """
    # imported locally so that the RDKit logging modules are only touched when needed
    from rdkit import rdBase as rd_base
    from rdkit import RDLogger as rd_logger

    # only messages of level ERROR and above go through the RDKit logger
    rd_logger.logger().setLevel(rd_logger.ERROR)
    # additionally switch off the error channel itself
    rd_base.DisableLog('rdApp.error')


def _build_parser():
    """Build the command-line argument parser of the script.

    :return: the configured argparse.ArgumentParser
    """
    description = """Script for standardizing molecules.
    Default protocol is:
        1) filter_empty
        2) disconnect_metal
        3) keep_best
        4) filter_num_heavy_atom
        5) filter_molecular_weight
        6) filter_num_ring
        7) filter_elements
        8) remove_isotopes
        9) normalize
        10) uncharge
        11) canonicalize
        12) remove_stereo

    At each step, molecules can thus either pass, get filtered or raise an error.
    In each case, the result is recorded in a corresponding output file.

    A default protocol can be defined by the user too, although it is currently only
    possible to remove steps and not add new ones.

    This script uses the installed npfc libary in your favorite env manager.

    Examples:

        >>> # Apply default protocol:
        >>> standardize_mols file_in.sdf.gz file_out.csv.gz
        >>> # Apply user-defined protocol:
        >>> standardize_mols file_in.sdf.gz file_out.csv.gz -p file_conf.json
        >>> # Apply default protocol but with saving filtered and error files:
        >>> standardize_mols file_in.sdf.gz file_out.csv.gz -f file_filtered.csv.gz -e file_error.csv.gz
        >>> # Apply default protocol with shorter cooldown:
        >>> standardize_mols file_in.sdf.gz file_out.csv.gz -t 5
    """
    parser = argparse.ArgumentParser(description=description, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('input_mols', type=str, default=None, help="Input file.")
    parser.add_argument('output_mols_std', type=str, default=None, help="Output file for standardized molecules.")
    parser.add_argument('-f', '--filtered', type=str, default=None, help="Output file for filtered molecules.")
    parser.add_argument('-e', '--error', type=str, default=None, help="Output file for molecules with errors.")
    parser.add_argument('-p', '--protocol', type=str, default=None, help="Configuration file in JSON for specifying a standardization protocol. See the docs for the default protocol.")
    parser.add_argument('-t', '--timeout', type=int, default=10, help="Maximum time in seconds allowed to standardize a molecule before filterintg it.")
    parser.add_argument('--log', type=str, default='INFO', help="Specify level of logging. Possible values are: CRITICAL, ERROR, WARNING, INFO, DEBUG.")
    return parser


def _get_optional_output_format(output_file):
    """Check an optional output file and determine its format.

    :param output_file: the path of the output file, or None if the output was not requested
    :return: whatever utils.get_file_format returns for the file (unpacked by the caller as format and compression), or (None, None) if output_file is None
    """
    if output_file is None:
        return None, None
    utils.check_arg_output_file(output_file)
    return utils.get_file_format(output_file)


def main():
    """Run the standardize_mols command-line script.

    The steps are:
        1) parse and check the command-line arguments
        2) log library versions, arguments and hard-coded values
        3) load the molecules from the input file
        4) run the standardization protocol on the loaded molecules
        5) save the passed molecules and, if requested, the filtered ones and the ones with errors
        6) log the time spent in each step

    The function ends by calling sys.exit(0), so it raises SystemExit instead of returning.
    """
    # init
    d0 = datetime.now()

    # parameters CLI

    args = _build_parser().parse_args()

    # logging

    logger = utils._configure_logger(args.log)
    logger.info("RUNNING MOLS_STANDARDIZE")
    warnings.filterwarnings('ignore', category=pd.io.pytables.PerformanceWarning)  # if None is returned instead of a molecule, do not complain about mixed types
    pad = 40
    disable_rdkit_logging()  # does not work... known issue... nothing much to do apart from editing the log file content. will do in smk directly

    # parse arguments

    # check on args values not already checked by argparse
    utils.check_arg_input_file(args.input_mols)
    utils.check_arg_output_file(args.output_mols_std)
    # several spellings are accepted for requesting the default protocol
    if args.protocol in ('None', 'DEFAULT', ''):
        args.protocol = None
    if args.protocol is None:
        logger.warning("USING DEFAULT STANDARDIZATION PROTOCOL")
        args.protocol = pkg_resources.resource_filename('npfc', 'data/std_mols.json')
    utils.check_arg_config_file(args.protocol)

    # in infos
    in_format, in_compression = utils.get_file_format(args.input_mols)

    # out infos
    # std
    out_format_std, out_compression_std = utils.get_file_format(args.output_mols_std)
    # filtered
    out_format_filtered, out_compression_filtered = _get_optional_output_format(args.filtered)
    # error
    out_format_error, out_compression_error = _get_optional_output_format(args.error)

    # hard-coded variables
    csv_sep = "|"
    in_id = "idm"
    in_mol = "mol"
    out_id = "idm"
    out_mol = "mol"
    encode = True
    decode = True

    # display infos

    # versions
    logger.info("LIBRARY VERSIONS:")
    logger.info("rdkit".ljust(pad) + f"{rdkit.__version__}")
    logger.info("pandas".ljust(pad) + f"{pd.__version__}")
    logger.info("npfc".ljust(pad) + f"{npfc.__version__}")
    # arguments
    logger.info("ARGUMENTS:")
    logger.info("INPUT_MOLS".ljust(pad) + f"{args.input_mols}")
    logger.info("OUTPUT_STD".ljust(pad) + f"{args.output_mols_std}")
    logger.info("OUTPUT_FILTERED".ljust(pad) + f"{args.filtered}")
    logger.info("OUTPUT_ERROR".ljust(pad) + f"{args.error}")
    # protocol: at this point args.protocol is always a file path, because the
    # default protocol has already been resolved to the file shipped with npfc
    logger.info("PROTOCOL".ljust(pad) + f"{args.protocol}")
    logger.info("LOG".ljust(pad) + f"{args.log}")

    # input
    logger.info("HARD-CODED OR DEDUCED VALUES:")
    logger.info("IN_ID".ljust(pad) + f"{in_id}")
    logger.info("IN_MOL".ljust(pad) + f"{in_mol}")
    logger.info("IN_FORMAT".ljust(pad) + f"{in_format}")
    logger.info("IN_COMPRESSION".ljust(pad) + f"{in_compression}")
    logger.info("DECODE".ljust(pad) + f"{decode}")
    logger.info("OUT_FORMAT_STD".ljust(pad) + f"{out_format_std}")
    logger.info("OUT_COMPRESSION_STD".ljust(pad) + f"{out_compression_std}")
    logger.info("OUT_FORMAT_FILTERED".ljust(pad) + f"{out_format_filtered}")
    logger.info("OUT_COMPRESSION_FILTERED".ljust(pad) + f"{out_compression_filtered}")
    logger.info("OUT_FORMAT_ERROR".ljust(pad) + f"{out_format_error}")
    logger.info("OUT_COMPRESSION_ERROR".ljust(pad) + f"{out_compression_error}")
    logger.info("OUT_MOL".ljust(pad) + f"{out_mol}")
    logger.info("OUT_ID".ljust(pad) + f"{out_id}")
    logger.info("ENCODE".ljust(pad) + f"{encode}")
    logger.info("CSV_SEP".ljust(pad) + f"{csv_sep}")

    # begin

    logger.info("BEGIN")

    # describe protocol
    logger.info("STANDARDIZING MOLECULES")
    logger.info("PROTOCOL:")
    s = standardize.Standardizer(protocol=args.protocol, timeout=args.timeout)
    # plain loops: logging is a side effect, no list needs to be built
    for i, task in enumerate(s._protocol['tasks']):
        logger.info(f"TASK #{str(i+1).zfill(2)} {task}")
    for opt, value in s._protocol.items():
        if opt != 'tasks':
            logger.info(f"OPTION {opt}".ljust(pad) + f"{value}")
    logger.info("TIMEOUT FOR ABOVE TASKS".ljust(pad) + f"{args.timeout} SEC")

    # load mols
    d1 = datetime.now()
    logger.info("LOADING MOLECULES")
    df_mols = load.file(args.input_mols, csv_sep=csv_sep)
    # records with a missing value in the molecule column are counted as failures
    num_failed = df_mols[in_mol].isna().sum()
    logger.info(f"LOADED {len(df_mols)} RECORDS WITH {num_failed} FAILURE(S)")

    # standardize molecules
    d2 = datetime.now()
    logger.info("RUN STANDARDIZATION")
    df_passed, df_filtered, df_error = s.run_df(df_mols)
    logger.info("RESULTS:")
    num_passed = len(df_passed.index)
    num_filtered = len(df_filtered.index)
    num_error = len(df_error.index)
    logger.info(''.join([f"{header}".ljust(10) for header in ('HEADER', 'PASSED', 'FILTERED', 'ERROR')]))
    logger.info(''.join([f"{header}".ljust(10) for header in ('COUNT', num_passed, num_filtered, num_error)]))

    d3 = datetime.now()
    logger.info("SAVING OUTPUTS")
    # std
    outputs = save.file(df_passed, args.output_mols_std, csv_sep=csv_sep)
    logger.info(f"SAVED {outputs[0][1]:,} STD RECORDS AT {outputs[0][0]}")
    # filtered
    if args.filtered is not None:
        outputs = save.file(df_filtered, args.filtered, csv_sep=csv_sep)
        logger.info(f"SAVED {outputs[0][1]:,} FILTERED RECORDS AT {outputs[0][0]}")
    # error
    if args.error is not None:
        outputs = save.file(df_error, args.error, csv_sep=csv_sep)
        logger.info(f"SAVED {outputs[0][1]:,} ERROR RECORDS AT {outputs[0][0]}")
    d4 = datetime.now()

    # end

    logger.info("SUMMARY")
    logger.info("COMPUTATIONAL TIME: CONFIGURING JOB".ljust(pad * 2) + f"{d1-d0}")
    logger.info("COMPUTATIONAL TIME: LOADING MOLECULES".ljust(pad * 2) + f"{d2-d1}")
    logger.info("COMPUTATIONAL TIME: STANDARDIZING MOLECULES".ljust(pad * 2) + f"{d3-d2}")
    logger.info("COMPUTATIONAL TIME: SAVING OUTPUTS".ljust(pad * 2) + f"{d4-d3}")
    logger.info("COMPUTATIONAL TIME: TOTAL".ljust(pad * 2) + f"{d4-d0}")
    logger.info("END")
    sys.exit(0)


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MAIN ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


if __name__ == "__main__":
    # run the command-line interface only when the file is executed as a script
    main()
