#!python

"""
Script subset_mols
==========================
This script subsets a molecular dataset using another dataset as reference.
Currently, the molecules found in both datasets can only be excluded from the molecular file.
"""

# standard
import sys
import warnings
from datetime import datetime
import logging
import argparse
from pathlib import Path
# data handling
import pandas as pd
# chemoinformatics
import rdkit
from rdkit import RDLogger
# custom libraries
import npfc
from npfc import utils


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FUNCTIONS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


def main():
    """Run the subset_mols command-line workflow.

    Steps:
        1. Parse the command line (mols, synref, output, --log).
        2. Load the reference synonym table from an HDF file.
        3. Load the molecular table from a gzipped, '|'-delimited CSV file.
        4. Drop every molecule whose 'inchikey' is found in the reference.
        5. Write the remaining rows as a gzipped, '|'-delimited CSV file.

    The function never returns normally: it ends with sys.exit(0) on success
    and sys.exit(1) when one of the tables has no 'inchikey' column.
    """
    # init
    d0 = datetime.now()
    description = """Subset a molecular file using a reference synonym file.

    Every compound in common between mols and synref is removed from mols.

    Require a molecular file and a reference synonym file. Both should have a column
    called 'inchikey' for matching compounds.

    The molecular file and the output file are gzip-compressed CSV files, the
    synonym file is a HDF file.

    For CSV files, delimiter '|' only is used.

    This command uses the installed npfc library in your favorite env manager.

    """

    # parameters CLI

    parser = argparse.ArgumentParser(description=description, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('mols', type=str, default=None, help="Input file for the molecules to filter.")
    parser.add_argument('synref', type=str, default=None, help="Synonym file for the reference to use as filter.")
    parser.add_argument('output', type=str, default=None, help="Output file.")
    parser.add_argument('--log', type=str, default='INFO', help="Specify level of logging. Possible values are: CRITICAL, ERROR, WARNING, INFO, DEBUG.")
    args = parser.parse_args()

    # logging
    logger = utils._configure_logger(args.log)
    logger.info("RUNNING SUBSET_MOLS")
    warnings.filterwarnings('ignore', category=pd.io.pytables.PerformanceWarning)  # if None is returned instead of a molecule, do not complain about mixed types
    pad = 40
    # silence RDKit messages below the CRITICAL level
    lg = RDLogger.logger()
    lg.setLevel(RDLogger.CRITICAL)

    # parse arguments
    # NOTE(review): the format/compression values returned by get_file_format
    # are not used below (I/O formats are hard-coded); the calls are kept in
    # case they validate the file names -- confirm against npfc.utils.

    # mols
    utils.check_arg_input_file(args.mols)
    utils.get_file_format(args.mols)

    # synref
    utils.check_arg_input_file(args.synref)
    utils.get_file_format(args.synref)

    # output
    utils.check_arg_output_file(args.output)
    utils.get_file_format(args.output)

    # display infos

    # versions
    logger.info("LIBRARY VERSIONS:")
    logger.info("rdkit".ljust(pad) + f"{rdkit.__version__}")
    logger.info("pandas".ljust(pad) + f"{pd.__version__}")
    logger.info("npfc".ljust(pad) + f"{npfc.__version__}")
    # arguments
    logger.info("ARGUMENTS:")
    logger.info("mols".ljust(pad) + f"{args.mols}")
    logger.info("synref".ljust(pad) + f"{args.synref}")
    logger.info("output".ljust(pad) + f"{args.output}")
    logger.info("LOG".ljust(pad) + f"{args.log}")

    # begin

    logger.info("BEGIN")

    # loading synref
    # the key is only reported in the log, it is not passed to read_hdf
    synref_key = Path(args.synref).stem
    logger.info(f"LOADING SYNREF WITH KEY={synref_key}")
    d1 = datetime.now()
    df_synref = pd.read_hdf(args.synref, decode=False)
    logger.info(f"NUMBER OF RECORDS IN SYNREF: {len(df_synref.index)}")

    # load mols
    logger.info("LOADING MOLS")
    d2 = datetime.now()
    df_mols = pd.read_csv(args.mols, sep="|", compression="gzip")  # no need for decoding mols

    # fail with an explicit message instead of a bare KeyError traceback
    for label, df in (("MOLS", df_mols), ("SYNREF", df_synref)):
        if 'inchikey' not in df.columns:
            logger.error(f"COLUMN 'inchikey' NOT FOUND IN {label}")
            sys.exit(1)

    # subset mols
    logger.info("SUBSETTING MOLS")
    d3 = datetime.now()
    num_ini = len(df_mols.index)
    # keep only the molecules whose inchikey is absent from the reference
    in_ref = df_mols['inchikey'].isin(df_synref['inchikey'])
    df_mols_subset = df_mols[~in_ref]
    logger.info(f"NUMBER OF REMAINING RECORDS IN SUBSET: {len(df_mols_subset.index)}/{num_ini}")

    # save results
    logger.info("SAVING RESULTS")
    d4 = datetime.now()
    # index=False: the index is the positional one created by read_csv (no
    # index_col is set), writing it would add an unnamed leading column that
    # is absent from the input file
    df_mols_subset.to_csv(args.output, sep="|", compression="gzip", index=False)

    # end
    d5 = datetime.now()
    logger.info("SUMMARY")
    logger.info("COMPUTATIONAL TIME: CONFIGURING JOB".ljust(pad * 2) + f"{d1-d0}")
    logger.info("COMPUTATIONAL TIME: LOADING SYNREF".ljust(pad * 2) + f"{d2-d1}")
    logger.info("COMPUTATIONAL TIME: LOADING INPUT_MOL".ljust(pad * 2) + f"{d3-d2}")
    logger.info("COMPUTATIONAL TIME: SUBSETIING MOLS".ljust(pad * 2) + f"{d4-d3}")
    logger.info("COMPUTATIONAL TIME: SAVING SUBSET".ljust(pad * 2) + f"{d5-d4}")
    logger.info("COMPUTATIONAL TIME: TOTAL".ljust(pad * 2) + f"{d5-d0}")
    logger.info("END")
    sys.exit(0)


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MAIN ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


# run the workflow only when this file is executed as a script, not on import
if __name__ == "__main__":
    main()
