#!python

"""
Script compute2D_mols
==========================
This script computes 2D coordinates for the molecules of an input file and saves the updated molecules to an output file.
"""

# standard
import sys
import warnings
from datetime import datetime
import logging
import argparse
# data
import pandas as pd
# chemoinformatics
import rdkit
from rdkit import RDLogger
# dev
import npfc
from npfc import load
from npfc import save
from npfc import draw
from npfc import utils


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FUNCTIONS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


def main():
    """Run the command-line workflow that computes 2D coordinates for molecules.

    Steps:
        1. Parse the command-line arguments and configure logging.
        2. Load the molecules from ``input_mols`` with ``load.file``.
        3. Replace every entry of the 'mol' column with the result of
           ``draw.depict_mol`` called with the requested methods.
        4. Save the updated DataFrame to ``output_mols`` with ``save.file``.
        5. Log the time spent in each step.

    On success the process exits with status 0 through ``sys.exit``. If no
    usable method name is left in ``--methods``, argparse reports the error
    and exits with its usual error status.
    """

    # init
    d0 = datetime.now()
    description = "Compute 2D coordinates for molecules."

    # parameters CLI

    parser = argparse.ArgumentParser(description=description, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('input_mols', type=str, default=None, help="Input file.")
    parser.add_argument('output_mols', type=str, default=None, help="Output file.")
    parser.add_argument('-m', '--methods', type=str, default="CoordGen,rdDepictor", help="Methods for computing 2D coordinates. Methods are applied in displayed order. If a depiction yields a 'perfect' score (0), then other following methods are not applied. If no perfect score is reached, the depiction with best (lowest) score is selected.")
    parser.add_argument('--log', type=str, default='INFO', help="Specify level of logging. Possible values are: CRITICAL, ERROR, WARNING, INFO, DEBUG.")

    args = parser.parse_args()

    # logging

    logger = utils._configure_logger(args.log)
    logger.info("RUNNING MOLS_DEPICT")
    # if None is returned instead of a molecule, do not complain about mixed types
    # (pd.errors is the public home of PerformanceWarning)
    warnings.filterwarnings('ignore', category=pd.errors.PerformanceWarning)
    pad = 40
    # only let critical RDKit messages through
    lg = RDLogger.logger()
    lg.setLevel(RDLogger.CRITICAL)

    # parse arguments

    # check on args values not already checked by argparse
    utils.check_arg_input_file(args.input_mols)
    utils.check_arg_output_file(args.output_mols)

    # IO infos
    in_format, in_compression = utils.get_file_format(args.input_mols)
    out_format, out_compression = utils.get_file_format(args.output_mols)

    # Methods: comma-separated names, spaces are ignored; empty entries
    # (e.g. caused by a trailing or doubled comma) are discarded so they are
    # never handed over to the depiction function as a method name
    methods = [m for m in args.methods.replace(" ", "").split(",") if m]
    if not methods:
        parser.error("argument -m/--methods: at least one method is required")

    # display infos

    # versions
    logger.info("LIBRARY VERSIONS:")
    logger.info("rdkit".ljust(pad) + f"{rdkit.__version__}")
    logger.info("pandas".ljust(pad) + f"{pd.__version__}")
    logger.info("npfc".ljust(pad) + f"{npfc.__version__}")
    # arguments
    # several values below are fixed names rather than CLI options: this keeps
    # the whole pipeline simple, they might become options again once the
    # structure of npfc is stable
    logger.info("ARGUMENTS:")
    logger.info("INPUT_MOLS".ljust(pad) + f"{args.input_mols}")
    logger.info("IN_ID".ljust(pad) + "idm")
    logger.info("IN_MOL".ljust(pad) + "mol")
    logger.info("IN_FORMAT".ljust(pad) + f"{in_format}")
    logger.info("IN_COMPRESSION".ljust(pad) + f"{in_compression}")
    logger.info("DECODE".ljust(pad) + "True")
    logger.info("OUTPUT_MOLS".ljust(pad) + f"{args.output_mols}")
    logger.info("OUT_ID".ljust(pad) + "idm")
    logger.info("OUT_MOL".ljust(pad) + "mol")
    logger.info("OUT_FORMAT".ljust(pad) + f"{out_format}")
    logger.info("OUT_COMPRESSION".ljust(pad) + f"{out_compression}")
    logger.info("ENCODE".ljust(pad) + "True")
    logger.info("CSV_SEP".ljust(pad) + "|")
    logger.info("METHODS".ljust(pad) + f"{args.methods}")
    logger.info("LOG".ljust(pad) + f"{args.log}")

    # begin
    logger.info("BEGIN")

    # load mols
    d1 = datetime.now()
    logger.info("LOADING MOLECULES")
    df_mols = load.file(args.input_mols)
    # records with a missing value in the 'mol' column are counted as failures
    num_failed = df_mols['mol'].isna().sum()
    logger.info(f"LOADED {len(df_mols)} RECORDS WITH {num_failed} FAILURE(S)")

    # compute 2D coordinates
    d2 = datetime.now()
    logger.info("COMPUTING 2D COORDINATES")
    df_mols['mol'] = df_mols['mol'].map(lambda x: draw.depict_mol(x, methods=methods))

    # save results
    d3 = datetime.now()
    logger.info("SAVING OUTPUTS")
    save.file(df_mols, output_file=args.output_mols)
    d4 = datetime.now()

    # end

    logger.info("SUMMARY")
    logger.info("COMPUTATIONAL TIME: CONFIGURING JOB".ljust(pad * 2) + f"{d1-d0}")
    logger.info("COMPUTATIONAL TIME: LOADING MOLECULES".ljust(pad * 2) + f"{d2-d1}")
    logger.info("COMPUTATIONAL TIME: COMPUTING 2D".ljust(pad * 2) + f"{d3-d2}")
    logger.info("COMPUTATIONAL TIME: SAVING OUTPUTS".ljust(pad * 2) + f"{d4-d3}")
    logger.info("COMPUTATIONAL TIME: TOTAL".ljust(pad * 2) + f"{d4-d0}")
    logger.info("END")
    sys.exit(0)


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MAIN ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


# run the workflow only when this file is executed as a script, not on import
if __name__ == "__main__":
    main()
