#!python

"""
Script refs_group
==========================
This script is used for grouping the refs on the group_on column.
"""

# standard
import sys
import warnings
from pathlib import Path
from datetime import datetime
import logging
import argparse
# data
import pandas as pd
# chemoinformatics
import rdkit
from rdkit import RDLogger
# dev
import npfc
from npfc import load
from npfc import save
from npfc import utils


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FUNCTIONS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #

def main():
    """Run the refs_group command-line tool.

    Parses the command line (``input_refs``, ``output_refs`` and ``--log``),
    loads the input table with ``npfc.load.file``, sorts it, groups the rows
    on the ``group_on`` column while joining the ``idm_kept`` values of each
    group into a single comma-separated ``ref`` string, and writes the result
    to ``output_refs`` as an HDF table stored under the key ``df``.

    The whole input is loaded into memory. Library versions, arguments and
    per-step timings are reported through the logger returned by
    ``npfc.utils._configure_logger``.

    Raises:
        SystemExit: always, with status 0, once the output has been written
            (argparse also raises it on invalid command-line arguments).
    """
    description = """Script used for grouping references.

    Warning! This script has to load the full input into memory.

    It uses the installed npfc libary in your favorite env manager.

    Examples:

        >>> refs_group input_refs.csv.gz output_refs.hdf
    """

    # parameters CLI

    parser = argparse.ArgumentParser(description=description, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('input_refs', type=str, default=None, help="Input file 1.")
    parser.add_argument('output_refs', type=str, default=None, help="Output file.")
    parser.add_argument('--log', type=str, default='INFO', help="Specify level of logging. Possible values are: CRITICAL, ERROR, WARNING, INFO, DEBUG.")
    args = parser.parse_args()
    # reference timestamp for the timing summary; taken once the arguments are parsed
    d0 = datetime.now()

    # logging

    logger = utils._configure_logger(args.log)
    logger.info("RUNNING REFS_GROUP")
    # if None is returned instead of a molecule, do not complain about mixed types
    # (pd.errors.PerformanceWarning is the public name of the class used by pandas.io.pytables)
    warnings.filterwarnings('ignore', category=pd.errors.PerformanceWarning)
    pad = 40
    # only let critical RDKit messages through
    lg = RDLogger.logger()
    lg.setLevel(RDLogger.CRITICAL)

    # parse arguments

    # check on args values not already checked by argparse
    utils.check_arg_input_file(args.input_refs)
    utils.check_arg_output_file(args.output_refs)

    # IO infos
    in_format, in_compression = utils.get_file_format(args.input_refs)
    out_format, out_compression = utils.get_file_format(args.output_refs)

    # hard-coded variables because I am the only one using this tool anyway
    csv_sep = "|"

    # display infos

    # versions
    logger.info("LIBRARY VERSIONS:")
    logger.info("rdkit".ljust(pad) + f"{rdkit.__version__}")
    logger.info("pandas".ljust(pad) + f"{pd.__version__}")
    logger.info("npfc".ljust(pad) + f"{npfc.__version__}")
    # arguments
    # replace many vars with defined names for simplifying the whole pipeline, these variables might be added back when the structure of npfc does not change anymore
    logger.info("ARGUMENTS:")
    logger.info("INPUT_REFS".ljust(pad) + f"{args.input_refs}")
    logger.info("IN_FORMAT".ljust(pad) + f"{in_format}")
    logger.info("IN_COMPRESSION".ljust(pad) + f"{in_compression}")
    logger.info("OUTPUT_REFS".ljust(pad) + f"{args.output_refs}")
    logger.info("OUT_FORMAT".ljust(pad) + f"{out_format}")
    logger.info("OUT_COMPRESSION".ljust(pad) + f"{out_compression}")
    logger.info("CSV_SEP".ljust(pad) + f"{csv_sep}")
    logger.info("LOG".ljust(pad) + f"{args.log}")

    # begin
    logger.info("BEGIN")

    # load refs
    d1 = datetime.now()
    logger.info("LOADING REFS")
    df = load.file(args.input_refs, csv_sep=csv_sep)
    # NOTE(review): rows are ordered on 'idm_filtered' but the values joined below
    # come from 'idm_kept' -- confirm that this is the intended ordering column
    df = df.sort_values(['group_on', 'idm_filtered'])
    logger.info(f"LOADED {len(df):,} RECORDS")
    logger.debug(f"DF:\n {df}\n")

    d2 = datetime.now()

    # group refs: one row per 'group_on' value, with the 'idm_kept' values of the
    # group joined into one comma-separated string stored in the 'ref' column
    df = df.groupby(['group_on'], as_index=False).agg({'idm_kept': ','.join}).rename({'idm_kept': 'ref'}, axis=1)

    # put the kept ref as first
    # df['ref'] = df.apply(lambda x: x['idm_kept'] + [s for s in x['ref'] if s != x['idm_kept']], axis=1)  # would need to be adapted to work
    logger.info(f"RESULT: {len(df):,} RECORDS")
    logger.debug(f"DF_REF:\n\n{df}\n")
    d3 = datetime.now()

    # save results
    logger.info("SAVING OUTPUT")
    # NOTE(review): the output is always written as HDF, whatever format was
    # detected from the output file name (out_format is only logged above).
    # 'key' is passed by keyword because recent pandas versions no longer accept
    # positional arguments after the path in to_hdf.
    df.to_hdf(args.output_refs, key='df', mode='w', format='table')
    # save.file(df, output_file=args.output_refs, csv_sep=csv_sep)
    logger.info(f"SAVED {len(df.index):,} RECORDS")
    d4 = datetime.now()

    # end

    logger.info("SUMMARY")
    logger.info("COMPUTATIONAL TIME: CONFIGURING JOB".ljust(pad * 2) + f"{d1-d0}")
    logger.info("COMPUTATIONAL TIME: LOADING REFS".ljust(pad * 2) + f"{d2-d1}")
    logger.info("COMPUTATIONAL TIME: GROUPING REFS".ljust(pad * 2) + f"{d3-d2}")
    logger.info("COMPUTATIONAL TIME: SAVING OUTPUTS".ljust(pad * 2) + f"{d4-d3}")
    logger.info("COMPUTATIONAL TIME: TOTAL".ljust(pad * 2) + f"{d4-d0}")
    logger.info("END")
    sys.exit(0)


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MAIN ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


# Run the command-line tool only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
