#!python

"""
Script filter_dupl_mols
==========================
This script is used for filtering out duplicated molecules.
"""

# standard
import sys
import warnings
from pathlib import Path
from datetime import datetime
import logging
import argparse
import random
import time
# data
import pandas as pd
# chemoinformatics
import rdkit
from rdkit import RDLogger
# dev
import npfc
from npfc import load
from npfc import save
from npfc import deduplicate
from npfc import utils


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FUNCTIONS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``argparse`` with ``type=bool`` simply calls ``bool()`` on the raw string,
    so every non-empty token (including ``"False"``) ends up being ``True``.
    This parser maps the usual spellings explicitly instead.

    :param value: the raw token found on the command line (or an actual bool)
    :return: the corresponding boolean
    :raises argparse.ArgumentTypeError: if the token is not a recognised spelling
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    # the empty string is kept falsy, as it was with bool('')
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value (e.g. True or False), got '{value}'")


def main():
    """Run the filter_dupl_mols command-line workflow.

    Steps, in order:

    1. parse and check the command-line arguments
    2. optionally delete an existing reference file (``--clear``)
    3. load the input molecules with ``load.file``
    4. remove duplicates with ``deduplicate.filter_duplicates``
    5. optionally export the duplicates (``--dupl-file``) and the synonyms
       (``--syn-file``) tables
    6. save the remaining molecules with ``save.file`` and log the timings

    The function reads ``sys.argv`` through argparse and always terminates the
    interpreter: ``sys.exit(0)`` on success, or the argparse exit status when
    the command line is invalid.
    """
    description = """
    """

    # parameters CLI

    parser = argparse.ArgumentParser(description=description, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('input_mols', type=str, default=None, help="Input file.")
    parser.add_argument('output_mols', type=str, default=None, help="Output file.")
    parser.add_argument('-g', '--group-on', type=str, default='inchikey', help="The column to use for grouping molecules. Currently two possible values: 'inchikey' or 'smiles'. If these columns are not present in the DataFrame, they are computed using the 'mol' column.")
    # booleans go through _str2bool: with type=bool, '--recompute-group False' would be parsed as True
    parser.add_argument('--recompute-group', type=_str2bool, default=False, help="Force the computation of the grouping column even if it is found in the input DataFrame.")
    parser.add_argument('-r', '--ref-file', type=str, default=None, help="A reference CSV file (.csv) where kept molecules are recorded. Useful for filtering duplicate molecules accross chunks. In case of parallel processing, a lock system prevents simultaneous reading/writing.")
    parser.add_argument('-d', '--dupl-file', type=str, default=None, help="A CSV file for exporting the list of duplicates. If none is defined, the user has to run this script with log level at least at DEBUG for reviewing the duplicates.")
    parser.add_argument('-s', '--syn-file', type=str, default=None, help="A CSV file for exporting the list of synonms (kept + filtered entries). If none is defined, the user has to run this script with log level at least at DEBUG for reviewing the duplicates.")
    parser.add_argument('-c', '--clear', type=_str2bool, default=False, help="Clear the defined reference file if it already exists.")
    parser.add_argument('--log', type=str, default='INFO', help="Specify level of logging. Possible values are: CRITICAL, ERROR, WARNING, INFO, DEBUG.")

    args = parser.parse_args()
    d0 = datetime.now()  # reference time for the job configuration step

    # logging

    logger = utils._configure_logger(args.log, reset_handlers=False)
    logger.info("RUNNING FILTER_DUPL_MOLS")
    warnings.filterwarnings('ignore', category=pd.errors.PerformanceWarning)  # if None is returned instead of a molecule, do not complain about mixed types
    pad = 40
    lg = RDLogger.logger()
    lg.setLevel(RDLogger.CRITICAL)

    # random start-up delay between 1 and 1000 ms
    # (presumably to desynchronise parallel jobs sharing the same ref file -- TODO confirm)
    waiting_ms = random.randint(1, 1000)
    logger.info("WAITING FOR %s MS BEFORE STARTING..." % waiting_ms)
    time.sleep(waiting_ms / 1000)  # time.sleep expects seconds

    # parse arguments

    # check on args values not already checked by argparse
    utils.check_arg_input_file(args.input_mols)
    utils.check_arg_output_file(args.output_mols)
    # all optional files default to None, so only check those actually provided
    for optional_file in (args.ref_file, args.dupl_file, args.syn_file):
        if optional_file is not None:
            utils.check_arg_output_file(optional_file)

    if args.clear:
        if args.ref_file is None:
            logger.warning("WARNING! NO REFERENCE FILE IS DEFINED FOR CLEARING")
        else:
            p = Path(args.ref_file)
            if p.exists():
                logger.info(f"CLEARING EXISTING REF_FILE AT '{args.ref_file}'")
                p.unlink()  # actually remove the file, as announced by the log message and the --clear help
            else:
                logger.warning(f"REF_FILE COULD NOT BE FOUND AT '{args.ref_file}', SO NOTHING TO CLEAR")

    # IO infos
    in_format, in_compression = utils.get_file_format(args.input_mols)
    out_format, out_compression = utils.get_file_format(args.output_mols)
    if args.ref_file is not None:
        # the returned format/compression are not used further; the call is kept
        # in case the helper rejects unsupported file names -- TODO confirm
        utils.get_file_format(args.ref_file)

    # hard-coded variables
    csv_sep = "|"
    in_id = "idm"
    in_mol = "mol"
    out_id = "idm"
    out_mol = "mol"
    encode = True
    decode = True

    # display infos

    # versions
    logger.info("LIBRARY VERSIONS:")
    for label, version in (("rdkit", rdkit.__version__),
                           ("pandas", pd.__version__),
                           ("npfc", npfc.__version__),
                           ):
        logger.info(label.ljust(pad) + f"{version}")
    # arguments
    # replace many vars with defined names for simplifying the whole pipeline, these variables might be added back when the structure of npfc does not change anymore
    logger.info("ARGUMENTS:")
    for label, value in (("INPUT_MOLS", args.input_mols),
                         ("IN_ID", in_id),
                         ("IN_MOL", in_mol),
                         ("IN_FORMAT", in_format),
                         ("IN_COMPRESSION", in_compression),
                         ("DECODE", decode),
                         ("OUTPUT_MOLS", args.output_mols),
                         ("OUT_ID", out_id),
                         ("OUT_MOL", out_mol),
                         ("OUT_FORMAT", out_format),
                         ("OUT_COMPRESSION", out_compression),
                         ("ENCODE", encode),
                         ("CSV_SEP", csv_sep),
                         ("GROUP_ON", args.group_on),
                         ("RECOMPUTE_GROUP", args.recompute_group),
                         ("REF_FILE", args.ref_file),
                         ("DUPL_FILE", args.dupl_file),
                         ("SYN_FILE", args.syn_file),
                         ("LOG", args.log),
                         ):
        logger.info(label.ljust(pad) + f"{value}")

    # begin
    logger.info("BEGIN")

    # load mols
    d1 = datetime.now()
    logger.info("LOADING MOLECULES")
    df_mols = load.file(args.input_mols, csv_sep=csv_sep)
    num_mols = len(df_mols)
    num_failed = df_mols[in_mol].isna().sum()  # records without a molecule count as failures
    logger.info(f"LOADED {num_mols} RECORDS WITH {num_failed} FAILURE(S)")

    # filter out duplicate molecules
    d2 = datetime.now()
    logger.info("FILTERING DUPLICATE MOLECULES")
    # recomputing group_on column: dropping it here leaves its computation to filter_duplicates
    if args.recompute_group:
        if args.group_on in df_mols.columns:
            logger.info(f"RECOMPUTING GROUP_ON COLUMN ('{args.group_on}')")
            df_mols.drop(args.group_on, axis=1, inplace=True)
        else:
            logger.warning(f"GROUP_ON COLUMN ('{args.group_on}') WAS NOT FOUND IN DATAFRAME, NEED TO COMPUTE IT")
    # filtering duplicate molecules
    df_mols, df_dupl = deduplicate.filter_duplicates(df_mols, group_on=args.group_on, ref_file=args.ref_file, get_df_dupl=True)

    logger.info(f"REMAINING MOLECULES: {len(df_mols.index)}/{num_mols}")
    logger.debug(f"FILTERED MOLECULES:\n{df_dupl}\n")

    if args.dupl_file is not None:
        save.file(df_dupl, output_file=args.dupl_file, csv_sep=csv_sep)

    if args.syn_file is not None:
        # generating the synonyms df: every kept molecule is listed as its own synonym,
        # then the filtered entries are appended
        # the grouping column is the one requested by the user, not necessarily 'inchikey'
        df_syn = df_mols[[args.group_on, out_id]].rename({out_id: 'idm_kept', args.group_on: 'group_on'}, axis=1)
        df_syn['idm_filtered'] = df_syn['idm_kept']
        df_syn = pd.concat([df_syn, df_dupl]).reset_index(drop=True).sort_values('idm_filtered')
        save.file(df_syn, output_file=args.syn_file, csv_sep=csv_sep)

    # save results
    d3 = datetime.now()
    logger.info("SAVING OUTPUTS")
    save.file(df_mols, output_file=args.output_mols, csv_sep=csv_sep)
    d4 = datetime.now()

    # end

    logger.info("SUMMARY")
    for label, elapsed in (("COMPUTATIONAL TIME: CONFIGURING JOB", d1 - d0),
                           ("COMPUTATIONAL TIME: LOADING MOLECULES", d2 - d1),
                           ("COMPUTATIONAL TIME: FILTERING DUPLICATES", d3 - d2),
                           ("COMPUTATIONAL TIME: SAVING OUTPUTS", d4 - d3),
                           ("COMPUTATIONAL TIME: TOTAL", d4 - d0),
                           ):
        logger.info(label.ljust(pad * 2) + f"{elapsed}")
    logger.info("END")
    sys.exit(0)


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MAIN ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


if __name__ == "__main__":
    # run the command-line workflow only when this file is executed as a script
    main()
