#!python

"""
Script mols_concat
==========================
This script is used for concatenating two datasets.
"""

# standard
import sys
import warnings
from pathlib import Path
from datetime import datetime
import logging
import argparse
# data
import pandas as pd
# chemoinformatics
import rdkit
from rdkit import RDLogger
# dev
import npfc
from npfc import load
from npfc import save
from npfc import utils


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FUNCTIONS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


def parse_arg_sortby(sortby: str) -> tuple:
    """Helper function to parse a list of formatted arguments (i.e. 'col1:a, col2:d')
    and return two lists (column names and booleans) to use for pandas sorting.

    By default, sorting is ascending, so 'col1, col2:d' is equivalent to 'col1:a, col2:d'.
    Whitespace around column names and order flags is ignored.

    :param sortby: a formatted string with each argument separated by ',' and each argument composed of a column name, optionally followed by ':' and 'a' (ascending) or 'd' (descending). If None, no sorting is requested.
    :return: a tuple of lists. First list contains the column names and second list contains booleans specifying the ascending sorting of the column. (None, None) if sortby is None.
    :raises ValueError: if an argument contains more than one ':', has an empty column name, or specifies an order other than 'a' or 'd'.
    """
    if sortby is None:
        return (None, None)

    names = []
    orders = []
    for i, raw in enumerate(sortby.split(','), start=1):
        arg = raw.strip()
        if arg.count(':') > 1:
            raise ValueError(f"ERROR! CANNOT PARSE ARGUMENT #{i} ('{arg}')!")

        # partition never fails: when ':' is absent, sep is '' and the
        # documented default (ascending) applies
        name, sep, order = arg.partition(':')
        name = name.strip()
        order = order.strip() if sep else 'a'

        if not name:
            # covers empty input, trailing commas and arguments such as ':a'
            raise ValueError(f"ERROR! CANNOT PARSE ARGUMENT #{i} ('{arg}')!")

        if order == 'a':
            orders.append(True)
        elif order == 'd':
            orders.append(False)
        else:
            raise ValueError(f"ERROR! ORDER SPECIFIED IN ARGUMENT #{i} IS INVALID ('{order}')!")
        names.append(name)
    return (names, orders)



def _count_atoms(mol):
    """Return ``mol.GetNumAtoms()``, or None when there is no molecule (None)."""
    return None if mol is None else mol.GetNumAtoms()


def _load_mols(input_file: str, label: str, csv_sep: str, logger) -> pd.DataFrame:
    """Load one input file and add the 'dataset' and 'hac' columns.

    :param input_file: path of the file to load with npfc's load.file
    :param label: suffix used for the 'dataset' value ('dataset_<label>') and in log messages
    :param csv_sep: separator forwarded to load.file
    :param logger: logger used for progress messages
    :return: the loaded DataFrame with 'idm' cast to str and columns 'dataset' and 'hac' (over)written
    """
    logger.info(f"LOADING MOLECULES {label}")
    df = load.file(input_file, csv_sep=csv_sep)
    df['dataset'] = f"dataset_{label}"
    df['idm'] = df['idm'].astype(str)
    # NOTE(review): failures are counted below with isna(), which suggests the
    # 'mol' column may contain None; guard so such rows do not crash the run.
    df['hac'] = df['mol'].map(_count_atoms)
    num_mols = len(df)
    num_failed = df['mol'].isna().sum()
    logger.debug(f"DF_MOLS_{label}:\n\n{df}\n")
    logger.info(f"LOADED {num_mols:,} RECORDS WITH {num_failed:,} FAILURE(S)")
    return df


def main():
    """Command-line entry point: load two molecular files, concatenate them,
    optionally sort the result, keep one record per 'idm' and save the output.

    Arguments are read from the command line (see the argparse definitions below).
    The function terminates the interpreter with exit code 0 on success.

    :raises ValueError: if --sortby cannot be parsed or --groupby is not 'idm'
    """

    # init
    d0 = datetime.now()
    description = """Script used for concatenating two moelcular files.

    Warning!!! During import of the files, columns 'dataset' and 'hac' are generated
    and would overwrite any pre-existing column with the same name.

    It uses the installed npfc libary in your favorite env manager.

    Examples:

        >>> # Concatenate two CSV files:
        >>> mols_concatenate file1.csv file2.csv.gz file_concat.csv.gz
        >>> # Concatenate a CSV and a SDF file and sort by hac and dataset:
        >>> mols_concatenate file1.csv.gz file2.sdf.gz file_concat.csv.gz -s 'hac:a, dataset:a'
    """

    # parameters CLI

    parser = argparse.ArgumentParser(description=description, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('input_mols_1', type=str, default=None, help="Input file 1.")
    parser.add_argument('input_mols_2', type=str, default=None, help="Input file 2.")
    parser.add_argument('output_mols', type=str, default=None, help="Output file.")
    parser.add_argument('-s', '--sortby', type=str, default=None, help="A list of column names following synthax 'col1:a, col2:d', for sorting the concatenated mols first by col1 in ascending order, then by col2 in descending order.")
    parser.add_argument('-g', '--groupby', type=str, default='idm', help="A column name to group molecules on.")
    parser.add_argument('--log', type=str, default='INFO', help="Specify level of logging. Possible values are: CRITICAL, ERROR, WARNING, INFO, DEBUG.")
    args = parser.parse_args()

    # logging

    logger = utils._configure_logger(args.log)
    logger.info("RUNNING MOLS_CONCATENATE")
    warnings.filterwarnings('ignore', category=pd.io.pytables.PerformanceWarning)  # if None is returned instead of a molecule, do not complain about mixed types
    pad = 40
    # silence RDKit messages below CRITICAL
    lg = RDLogger.logger()
    lg.setLevel(RDLogger.CRITICAL)

    # parse arguments

    # check on args values not already checked by argparse
    utils.check_arg_input_file(args.input_mols_1)
    utils.check_arg_input_file(args.input_mols_2)
    utils.check_arg_output_file(args.output_mols)
    colnames, orders = parse_arg_sortby(args.sortby)

    # only grouping on 'idm' is implemented
    if args.groupby != 'idm' and args.groupby is not None:
        raise ValueError(f"ERROR! UNSUPPORTED VALUE FOR ARGUMENT GROUPBY ('{args.groupby}')!")

    # IO infos
    in_format_1, in_compression_1 = utils.get_file_format(args.input_mols_1)
    in_format_2, in_compression_2 = utils.get_file_format(args.input_mols_2)
    out_format, out_compression = utils.get_file_format(args.output_mols)

    # hard-coded variables because I am the only one using this tool anyway
    csv_sep = "|"
    in_id = "idm"
    in_mol = "mol"
    out_id = "idm"
    out_mol = "mol"
    encode = True
    decode = True

    # display infos

    # versions
    logger.info("LIBRARY VERSIONS:")
    logger.info("rdkit".ljust(pad) + f"{rdkit.__version__}")
    logger.info("pandas".ljust(pad) + f"{pd.__version__}")
    logger.info("npfc".ljust(pad) + f"{npfc.__version__}")
    # arguments
    # replace many vars with defined names for simplifying the whole pipeline, these variables might be added back when the structure of npfc does not change anymore
    logger.info("ARGUMENTS:")
    logger.info("INPUT_MOLS_1".ljust(pad) + f"{args.input_mols_1}")
    logger.info("IN_ID_1".ljust(pad) + f"{in_id}")
    logger.info("IN_MOL_1".ljust(pad) + f"{in_mol}")
    logger.info("IN_FORMAT_1".ljust(pad) + f"{in_format_1}")
    logger.info("IN_COMPRESSION_1".ljust(pad) + f"{in_compression_1}")
    # report the values of the second input (these lines used to repeat input 1)
    logger.info("INPUT_MOLS_2".ljust(pad) + f"{args.input_mols_2}")
    logger.info("IN_ID_2".ljust(pad) + f"{in_id}")
    logger.info("IN_MOL_2".ljust(pad) + f"{in_mol}")
    logger.info("IN_FORMAT_2".ljust(pad) + f"{in_format_2}")
    logger.info("IN_COMPRESSION_2".ljust(pad) + f"{in_compression_2}")
    logger.info("DECODE".ljust(pad) + f"{decode}")
    logger.info("OUTPUT_MOLS".ljust(pad) + f"{args.output_mols}")
    logger.info("OUT_ID".ljust(pad) + f"{out_id}")
    logger.info("OUT_MOL".ljust(pad) + f"{out_mol}")
    logger.info("OUT_FORMAT".ljust(pad) + f"{out_format}")
    logger.info("OUT_COMPRESSION".ljust(pad) + f"{out_compression}")
    logger.info("ENCODE".ljust(pad) + f"{encode}")
    logger.info("CSV_SEP".ljust(pad) + f"{csv_sep}")
    logger.info("SORTBY".ljust(pad) + f"{args.sortby}")
    logger.info("GROUPBY".ljust(pad) + f"{args.groupby}")

    logger.info("LOG".ljust(pad) + f"{args.log}")

    # begin
    logger.info("BEGIN")

    # load mols 1
    d1 = datetime.now()
    df_mols_1 = _load_mols(args.input_mols_1, "1", csv_sep, logger)

    # load mols 2
    d2 = datetime.now()
    df_mols_2 = _load_mols(args.input_mols_2, "2", csv_sep, logger)

    # concatenate both datasets (columns are sorted alphabetically by sort=True)
    d3 = datetime.now()
    logger.info("CONCATENATING MOLECULES")
    df_mols = pd.concat([df_mols_1, df_mols_2], sort=True)

    # attempt to consider idm as int, if possible
    try:
        logger.debug(f"ATTEMPTING TO CONVERT IDM COL '{in_id}' TO INT")
        df_mols['idm'] = df_mols['idm'].astype(int)
    except ValueError:
        logger.debug(f"IDM COL '{in_id}' CANNOT BE CONVERTED TO INT")

    # sort molecules using specified properties so the first of a group is the desired one
    if args.sortby is not None:
        logger.info("SORTING MOLECULES")
        df_mols = df_mols.sort_values(colnames, ascending=orders)

    # groupby molecules by idm
    # NOTE(review): GroupBy.first() takes the first non-null value per column,
    # so a kept record can combine values from several rows sharing an idm
    # when some columns contain nulls -- confirm this is intended.
    if args.groupby is not None:
        logger.info("GROUPING MOLECULES")
        df_mols = df_mols.groupby('idm').first().reset_index()

    logger.debug(f"FINAL RESULTS:\n\n{df_mols}\n")

    # save results
    d4 = datetime.now()
    logger.info("SAVING OUTPUT")
    save.file(df_mols, output_file=args.output_mols, csv_sep=csv_sep)
    logger.info(f"SAVED {len(df_mols.index):,} RECORDS")
    d5 = datetime.now()

    # end

    logger.info("SUMMARY")
    logger.info("COMPUTATIONAL TIME: CONFIGURING JOB".ljust(pad * 2) + f"{d1-d0}")
    logger.info("COMPUTATIONAL TIME: LOADING MOLECULES 1".ljust(pad * 2) + f"{d2-d1}")
    logger.info("COMPUTATIONAL TIME: LOADING MOLECULES 2".ljust(pad * 2) + f"{d3-d2}")
    logger.info("COMPUTATIONAL TIME: CONCATENATING DUPLICATES".ljust(pad * 2) + f"{d4-d3}")
    logger.info("COMPUTATIONAL TIME: SAVING OUTPUTS".ljust(pad * 2) + f"{d5-d4}")
    logger.info("COMPUTATIONAL TIME: TOTAL".ljust(pad * 2) + f"{d5-d0}")
    logger.info("END")
    sys.exit(0)


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MAIN ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


# run the command-line entry point only when executed as a script, not on import
if "__main__" == __name__:
    main()
