#!python

"""
Script load_mols
==========================
This script is used for loading molecules from SDF, CSV or HDF files and then
exporting them to CSV or HDF files with RDKit Mol objects.
"""

# standard
from pathlib import Path
import sys
import warnings
from datetime import datetime
import logging
import argparse
# data
import pandas as pd
# chemoinformatics
import rdkit
from rdkit import RDLogger
# custom libraries
import npfc
from npfc import load
from npfc import save
from npfc import utils


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FUNCTIONS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


def _str_to_bool(value):
    """Convert a command-line string into a boolean.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so any
    non-empty value (including "False") evaluates to True. This parser keeps
    the "option takes a value" interface while interpreting the value properly.

    :param value: the raw command-line value (or an actual bool)
    :return: the corresponding boolean
    :raises argparse.ArgumentTypeError: if the value is not a recognized boolean
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    # the empty string stays False, as it was with bool('')
    if normalized in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value (True/False) but got '{value}'")


def main():
    """Run the load_mols command-line workflow.

    Steps: parse the command-line arguments, resolve the input files, load
    them into a single DataFrame, discard the records without a molecule and
    save the remaining records to the output file(s).

    :raises ValueError: if the input argument is neither a comma-separated list, a file nor a directory, or if no input file could be found
    :raises SystemExit: always, with code 0, once the job has completed
    """

    # init
    d0 = datetime.now()
    description = """Script used for loading molecules from SDF, CSV or HDF files, convert them into RDKit objects and export them into CSV or HDF files.
    Molecules that failed the RDKit conversion have None for structure but their properties, if any, remain.

    It uses the installed npfc library in your favorite env manager.

    Examples:

        >>> # Convert a SDF into a HDF using molecule titles as molecule id
        >>> load_mols file_in.sdf file_out.hdf --in_id _Name
        >>> # Chunk a CSV file into SDF files of 100 randomly ordered records while keeping all properties
        >>> load_mols file_in.csv file_out.sdf -n 100 -k True -s True --in_id prop1 --in_mol mol --out_id _Name
    """

    # parameters CLI

    parser = argparse.ArgumentParser(description=description, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('input_mols', type=str, default=None, help="Input file.")
    parser.add_argument('output_mols', type=str, default=None, help="Output file. If chunking, then it is used as prefix.")
    parser.add_argument('-n', '--nrecords', type=int, default=None, help="Maximum number of records within a chunk. None means no chunking.")
    # boolean options use _str_to_bool because type=bool would turn "False" into True
    parser.add_argument('-s', '--shuffle', type=_str_to_bool, default=False, help="Shuffle records.")
    parser.add_argument('-k', '--keep_props', type=_str_to_bool, default=False, help="Keep all properties found in the source, in addition to molid and mol columns.")
    parser.add_argument('-e', '--encode', type=_str_to_bool, default=True, help="Encode objects into base64 strings.")
    parser.add_argument('-d', '--decode', type=_str_to_bool, default=True, help="Decode objects from base64 strings.")
    parser.add_argument('--in_id', type=str, default='idm', help="Identifier column in the source file.")
    parser.add_argument('--in_mol', type=str, default='mol', help="Molecule column in the source file.")
    parser.add_argument('--in_sep', type=str, default='|', help="Separator to use in case the input file is a csv.")
    parser.add_argument('--out_id', type=str, default='idm', help="Identifier column in the output file.")
    parser.add_argument('--out_mol', type=str, default='mol', help="Molecule column in the output file.")
    parser.add_argument('--out_sep', type=str, default='|', help="Separator to use in case the output file is a csv.")
    parser.add_argument('--random_seed', type=int, default=None, help="Value to specify to make shuffling reproducible.")
    parser.add_argument('--log', type=str, default='INFO', help="Specify level of logging. Possible values are: CRITICAL, ERROR, WARNING, INFO, DEBUG.")
    args = parser.parse_args()

    # logging

    logger = utils._configure_logger(args.log)
    logger.info("RUNNING LOAD_MOLS")
    warnings.filterwarnings('ignore', category=pd.io.pytables.PerformanceWarning)  # if None is returned instead of a molecule, do not complain about mixed types
    pad = 40
    # only let critical RDKit messages through
    lg = RDLogger.logger()
    lg.setLevel(RDLogger.CRITICAL)

    # parse arguments

    # check on args values not already checked by argparse
    if ',' in args.input_mols:
        input_files = [x.strip() for x in args.input_mols.split(',')]
    elif Path(args.input_mols).is_file():
        input_files = [args.input_mols]
    elif Path(args.input_mols).is_dir():
        input_files = sorted([str(x) for x in Path(args.input_mols).glob('*')])
    else:
        raise ValueError('ERROR! DOES NOT KNOW WHAT TO DO WITH INPUT MOLS (%s)' % args.input_mols)

    # an empty directory yields no input file
    if len(input_files) < 1:
        raise ValueError('ERROR! COULD NOT FIND ANY INPUT MOLS AT "%s"' % args.input_mols)

    for input_file in input_files:
        # the input files are also reported at INFO level in the ARGUMENTS section below
        logger.debug(f"input_file: {input_file}")
        utils.check_arg_input_file(input_file)

    utils.check_arg_output_file(args.output_mols)
    utils.check_arg_positive_number(args.nrecords)
    # src infos
    in_format, in_compression = utils.get_file_format(input_files[0])  # input files are supposed to be homogeneous
    # out infos
    out_format, out_compression = utils.get_file_format(args.output_mols)

    # display infos

    # versions
    logger.info("LIBRARY VERSIONS")
    logger.info("rdkit".ljust(pad) + f"{rdkit.__version__}")
    logger.info("pandas".ljust(pad) + f"{pd.__version__}")
    logger.info("npfc".ljust(pad) + f"{npfc.__version__}")
    # arguments
    logger.info("ARGUMENTS")
    if len(input_files) > 1:
        for i, input_file in enumerate(input_files):
            logger.info(f"INPUT_FILE {str(i+1).zfill(2)}".ljust(pad) + f"{input_file}")
    else:
        logger.info("INPUT_FILE".ljust(pad) + f"{input_files[0]}")
    logger.info("OUTPUT_MOLS".ljust(pad) + f"{args.output_mols}")
    logger.info("NRECORDS".ljust(pad) + f"{args.nrecords}")
    logger.info("SHUFFLE".ljust(pad) + f"{args.shuffle}")
    logger.info("KEEP_PROPS".ljust(pad) + f"{args.keep_props}")
    logger.info("ENCODE".ljust(pad) + f"{args.encode}")
    logger.info("DECODE".ljust(pad) + f"{args.decode}")
    logger.info("IN_ID".ljust(pad) + f"{args.in_id}")
    logger.info("IN_MOL".ljust(pad) + f"{args.in_mol}")
    logger.info("IN_FORMAT".ljust(pad) + f"{in_format}")
    logger.info("IN_COMPRESSION".ljust(pad) + f"{in_compression}")
    logger.info("OUT_MOL".ljust(pad) + f"{args.out_mol}")
    logger.info("OUT_ID".ljust(pad) + f"{args.out_id}")
    logger.info("OUT_FORMAT".ljust(pad) + f"{out_format}")
    logger.info("OUT_COMPRESSION".ljust(pad) + f"{out_compression}")
    logger.info("LOG".ljust(pad) + f"{args.log}")

    # begin
    logger.info("BEGIN")

    # load mols
    logger.info("LOADING MOLECULES")
    d1 = datetime.now()
    df_mols = pd.concat([load.file(x,
                                   in_id=args.in_id,
                                   out_id=args.out_id,
                                   in_mol=args.in_mol,
                                   out_mol=args.out_mol,
                                   keep_props=args.keep_props,
                                   decode=args.decode,
                                   csv_sep=args.in_sep
                                   ) for x in input_files])

    # filter out failed molecules
    d2 = datetime.now()
    # NOTE(review): assumes load.file names the molecule column after out_mol
    # (the same name is handed to save.file as col_mol below) -- confirm
    is_failed = df_mols[args.out_mol].isna()
    num_failed = int(is_failed.sum())
    logger.info(f"LOADED {len(df_mols)} RECORDS WITH {num_failed} FAILURE(S)")
    # select by position rather than by index label: pd.concat keeps the index
    # of every input file, so labels may be duplicated across files and a
    # label-based filter would also discard valid records
    df_mols = df_mols[~is_failed.values]

    # save mols
    d3 = datetime.now()
    logger.info("SAVING MOLECULES")
    outputs_files = save.file(df_mols,
                              args.output_mols,
                              col_mol=args.out_mol,
                              col_id=args.out_id,
                              chunk_size=args.nrecords,
                              shuffle=args.shuffle,
                              encode=args.encode,
                              random_seed=args.random_seed,
                              csv_sep=args.out_sep,
                              )
    # display results
    for output in outputs_files:
        logger.info(f"SAVED {output[1]} RECORDS AT {output[0]}")
    d4 = datetime.now()

    # end

    logger.info("SUMMARY")
    logger.info("COMPUTATIONAL TIME: CONFIGURING JOB".ljust(pad * 2) + f"{d1-d0}")
    logger.info("COMPUTATIONAL TIME: LOADING MOLECULES".ljust(pad * 2) + f"{d2-d1}")
    logger.info("COMPUTATIONAL TIME: FILTERING OUT FAILED".ljust(pad * 2) + f"{d3-d2}")
    logger.info("COMPUTATIONAL TIME: SAVING MOLECULES".ljust(pad * 2) + f"{d4-d3}")
    logger.info("COMPUTATIONAL TIME: TOTAL".ljust(pad * 2) + f"{d4-d0}")
    logger.info("END")
    sys.exit(0)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ RUN ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


# Call main() only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()
