#!/usr/bin/env python

"""
Script mols_count
==================
This script is used for counting how many molecules there are in an input file.
"""

# standard
from multiprocessing.sharedctypes import Value
import sys
import warnings
from pathlib import Path
from datetime import datetime
import logging
import argparse
import random
import time
import json
# data
import pandas as pd
# chemoinformatics
import rdkit
from rdkit import RDLogger
# dev
import npfc
from npfc import load
from npfc import save
from npfc import utils


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FUNCTIONS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


def main():
    """Count the molecules of an input file and optionally cache the result.

    Command-line arguments are parsed with argparse (see the parser below).
    When ``--output-count`` designates an existing JSON file and ``--force``
    is false, the count is read back from that file. Otherwise the count is
    computed with ``load.count_mols`` and, if ``--output-count`` was given,
    (over)written to that file as ``{"num_mols": <count>}``.

    Raises:
        ValueError: if ``--engine`` is neither 'python' nor 'unix'.

    The function terminates the interpreter with ``sys.exit(0)`` on success.
    """

    # init
    description = """mols_count
    A script for counting how many molecules there are in an input file. Molecules are NOT parsed.
    A JSON file can be generated to record the number of molecules for further processing (output_count).
    By default, the count is not performed again if the specified output_count file already exists. 

    As for most hyper-specialized tools, some methods are more optimized than others. By default this tool
    will try to use UNIX commands to retrieve the count, but can fall back to Python only, if needed. 

    For now, only three input formats are supported: CSV (SMILES), HDF and SDF.
    """

    # parameters CLI

    parser = argparse.ArgumentParser(description=description, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('input_mols', type=str, default=None, help="Input file.")
    parser.add_argument('-o', '--output-count', type=str, default=None, help="A JSON output file where to store the mols count.")
    parser.add_argument('-f', '--force', type=utils.parse_argparse_boolstring, default=False, help="Force the script to ignore (and recompute) the specified JSON output_count file")
    parser.add_argument('-e', '--engine', type=str, default='unix', help="Either 'python' or 'unix'. The latter is used to run UNIX commands, when available for SDF and CSV input formats. If not available, fall-back to pure-Python.")
    parser.add_argument('-u', '--keep-uncompressed', type=utils.parse_argparse_boolstring, default=False, help="In case of a compressed input file (gzip), do not delete the uncompressed temporary file after execution.")
    parser.add_argument('-k', '--hdf-key', type=str, default=None, help="For HDF input only, by default the filename is used as key.")
    parser.add_argument('--csv-header', type=utils.parse_argparse_boolstring, default=True, help="For CSV input only, specifiy whether the input file has headers.")
    parser.add_argument('--log', type=str, default='INFO', help="Specify level of logging. Possible values are: CRITICAL, ERROR, WARNING, INFO, DEBUG.")
    args = parser.parse_args()
    # reference timestamp for the summary; taken once the CLI has been parsed
    d0 = datetime.now()

    # logging

    logger = utils._configure_logger(args.log, reset_handlers=False)
    logger.info("RUNNING MOLS_COUNT")
    warnings.filterwarnings('ignore', category=pd.io.pytables.PerformanceWarning)  # if None is returned instead of a molecule, do not complain about mixed types
    pad = 40
    # only let critical RDKit messages through
    RDLogger.logger().setLevel(RDLogger.CRITICAL)

    # parse arguments

    # check on args values not already checked by argparse
    utils.check_arg_input_file(args.input_mols)
    if args.output_count is not None:
        utils.check_arg_output_config(args.output_count)
    utils.check_arg_bool(args.force)
    utils.check_arg_bool(args.keep_uncompressed)

    # engine
    ENGINES = ['python', 'unix']
    if args.engine not in ENGINES:
        # exceptions do not interpolate %-style arguments, so build the message explicitly
        raise ValueError(f"ERROR! ENGINE '{args.engine}' NOT FOUND IN {', '.join(ENGINES)}")

    # IO infos
    in_format, in_compression = utils.get_file_format(args.input_mols)

    # hdf_key: only meaningful for HDF input, defaults to the input file stem
    if in_format == 'HDF':
        hdf_key = Path(args.input_mols).stem if args.hdf_key is None else args.hdf_key
    else:
        hdf_key = None

    # display infos

    # versions
    logger.info("LIBRARY VERSIONS:")
    logger.info("rdkit".ljust(pad) + f"{rdkit.__version__}")
    logger.info("pandas".ljust(pad) + f"{pd.__version__}")
    logger.info("npfc".ljust(pad) + f"{npfc.__version__}")
    # arguments
    logger.info("ARGUMENTS:")
    logger.info("INPUT_MOLS".ljust(pad) + f"{args.input_mols}")
    logger.info("IN_FORMAT".ljust(pad) + f"{in_format}")
    logger.info("IN_COMPRESSION".ljust(pad) + f"{in_compression}")
    if in_format == 'HDF':
        logger.info("IN_HDF_KEY".ljust(pad) + f"{hdf_key}")
    elif in_format == 'CSV':
        logger.info("IN_CSV_HEADER".ljust(pad) + f"{args.csv_header}")
    logger.info("OUTPUT_COUNT_FILE".ljust(pad) + f"{args.output_count}")
    logger.info("FORCE".ljust(pad) + f"{args.force}")
    logger.info("ENGINE".ljust(pad) + f"{args.engine}")
    logger.info("LOG".ljust(pad) + f"{args.log}")

    # begin
    logger.info("BEGIN")
    d1 = datetime.now()
    # the cached count is reused only if a count file was requested, it already
    # exists on disk and the user did not ask to recompute it
    use_cached_count = (
        args.output_count is not None
        and not args.force
        and Path(args.output_count).is_file()
    )
    if use_cached_count:
        # do not perform count but rather read result from file
        with open(args.output_count, encoding='utf-8') as FILE:
            data = json.load(FILE)
        num_mols = data['num_mols']
        logger.info("NUMBER OF MOLECULES FOUND IN COUNT FILE: %s", '{:,}'.format(num_mols))
    else:
        # perform counting
        num_mols = load.count_mols(args.input_mols,
                                   keep_uncompressed=args.keep_uncompressed,
                                   hdf_key=hdf_key,
                                   csv_header=args.csv_header,
                                   engine=args.engine,
                                   )
        logger.info("NUMBER OF MOLECULES FOUND IN INPUT FILE: %s", '{:,}'.format(num_mols))

        # (over)write count file: created on first run, refreshed when forced
        if args.output_count is not None:
            with open(args.output_count, 'w', encoding='utf-8') as FILE:
                json.dump({'num_mols': num_mols}, FILE)

    d2 = datetime.now()

    # end

    logger.info("SUMMARY")
    logger.info("COMPUTATIONAL TIME: CONFIGURING JOB".ljust(pad * 2) + f"{d1-d0}")
    logger.info("COMPUTATIONAL TIME: COUNTING MOLECULES".ljust(pad * 2) + f"{d2-d1}")
    logger.info("COMPUTATIONAL TIME: TOTAL".ljust(pad * 2) + f"{d2-d0}")
    logger.info("END")
    sys.exit(0)


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MAIN ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()
