#!python

"""
Script mols_count
==========================
This script is used for counting how many molecules are still present at the end
of a step. It is designed to append a row per chunk to a common output file.
"""

# standard
import warnings
from pathlib import Path
import sys
from datetime import datetime
import pandas as pd
import argparse
from pandas import DataFrame
# chemoinformatics
import rdkit
# dev
import npfc
from npfc import load
from npfc import utils
import subprocess
# disable SettingWithCopyWarning warnings
pd.options.mode.chained_assignment = None  # default='warn'


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FUNCTIONS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


def find_input_files(input_dir: str, pattern: str, exclude: list = None) -> list:
    """Find input files corresponding to the pattern.

    The input directory is searched recursively. Only regular files are
    returned: directories whose name happens to match the pattern are skipped,
    since they cannot be parsed as input files.

    :param input_dir: input directory where input files and subdirectories are stored.
    :param pattern: pattern to use for globbing input files. (i.e. 'chembl_001*')
    :param exclude: a list of substrings that can be found in the input file paths to exclude some of those. (i.e. files without the groupby column, etc). Defaults to ['/report/', '_ref.hdf', '/log/'] when None.
    :return: a sorted list of input files
    """
    # a None sentinel avoids sharing one mutable default list between calls
    if exclude is None:
        exclude = ['/report/', '_ref.hdf', '/log/']
    # materialize once so that any iterable can be reused for every path
    exclude = tuple(exclude)

    input_files = [
        str(path)
        for path in Path(input_dir).rglob(pattern)
        if path.is_file()
    ]
    # single pass: drop any path containing at least one excluded substring
    return sorted(x for x in input_files if not any(e in x for e in exclude))


def count_mols_sdf(input_sdf: str) -> int:
    """Count the number of molecules for a given SD File by counting the '$$$$' delimiter lines.

    The file is read directly in Python instead of spawning grep/zgrep through
    a shell, so paths containing spaces or shell metacharacters are handled
    safely and no external tool is required.

    :param input_sdf: an input SD File, optionally gzip-compressed (a '.gz' suffix in the file name).
    :return: the number of molecules in the input SDF
    :raises OSError: if the file cannot be opened or read.
    """
    import gzip  # local import so this function does not rely on a new file-level import

    delimiter = b'$$$$'
    # same compression detection as before: any '.gz' among the suffixes
    opener = gzip.open if '.gz' in Path(input_sdf).suffixes else open
    # binary mode: no decoding needed (or wanted) just to spot the delimiter
    with opener(input_sdf, 'rb') as handle:
        return sum(1 for line in handle if line.startswith(delimiter))


def count_mols(input_file: str, groupby: str) -> tuple:
    """Count the entries and the molecules stored in an input file.

    The file is loaded with ``load.file``; the number of molecules is the
    number of groups obtained when grouping the loaded table by ``groupby``.

    :param input_file: the input file
    :param groupby: the column name to group the input file by.
    :return: a tuple with syntax: (number of rows, number of molecules); (0, 0) when the loaded table is empty.
    """
    table = load.file(input_file)
    num_rows = len(table)
    # guard clause: nothing to group in an empty table
    if num_rows <= 0:
        return (0, 0)
    num_groups = len(table.groupby(groupby))
    return (num_rows, num_groups)


def main():
    """Command-line entry point: count molecules per pipeline step for one chunk.

    Parses the command-line arguments, looks up the files matching the pattern
    below the input directory, counts the entries and molecules of each file
    and writes all counts as one '|'-separated row to the output csv file.
    The row is appended without header when the output file already exists,
    otherwise the file is created with a header.
    """
    # init
    t_start = datetime.now()
    description = """Script is used for counting molecules in each step of the pipeline
    for a given chunk.

    I takes two inputs:

        - input directory
        - pattern for matching the corresponding chunk

    It creates one output:
        - a csv file with all molecule counts for each step as a single row.

    All chunk molecule counts can then easilsy gathered manually using pandas.

    It uses the installed npfc libary in your favorite env manager.

    Example:

        >>> mols_count fc/04_synthetic/chembl/data/prep 'chembl_001*' fc/04_synthetic/chembl/data/prep/report/chembl_001_mols_count.csv

    """

    # parameters CLI
    parser = argparse.ArgumentParser(
        description=description,
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        'input_dir', type=str, default=None,
        help="Input directory (i.e. fc/04_synthetic/chembl/data/prep/03_std/data)",
    )
    parser.add_argument(
        'pattern', type=str, default=None,
        help="Pattern of the file (i.e. prefix of a chunk) to investigate. All matching files in subdirectories will be processed.",
    )
    parser.add_argument(
        'output_csv', type=str, default=None,
        help="Output file for recording the molecule count.",
    )
    parser.add_argument(
        '-g', '--groupby', type=str, default='idm',
        help="Column name to use for grouping molecules.",
    )
    parser.add_argument(
        '--log', type=str, default='INFO',
        help="Specify level of logging. Possible values are: CRITICAL, ERROR, WARNING, INFO, DEBUG.",
    )
    args = parser.parse_args()

    # logging
    logger = utils._configure_logger(args.log)
    logger.info("RUNNING MOLECULE COUNTING")
    # if None is returned instead of a molecule, do not complain about mixed types
    warnings.filterwarnings('ignore', category=pd.io.pytables.PerformanceWarning)
    pad = 40

    def log_field(label, value, width=pad):
        """Log a left-justified label followed by its value."""
        logger.info(label.ljust(width) + f"{value}")

    # parse arguments
    utils.check_arg_input_dir(args.input_dir)
    utils.check_arg_output_file(args.output_csv)
    out_format, out_compression = utils.get_file_format(args.output_csv)

    # display infos
    logger.info("LIBRARY VERSIONS:")
    log_field("rdkit", rdkit.__version__)
    log_field("pandas", pd.__version__)
    log_field("npfc", npfc.__version__)
    logger.info("ARGUMENTS:")
    log_field("INPUT_DIR", args.input_dir)
    log_field("PATTERN", args.pattern)
    log_field("OUTPUT_FILE", args.output_csv)
    log_field("OUT_FORMAT", out_format)
    log_field("OUT_COMPRESSION", out_compression)
    log_field("GROUBPY", args.groupby)
    log_field("LOG", args.log)
    t_configured = datetime.now()

    # begin
    logger.info("BEGIN")

    # finding the input files
    input_files = find_input_files(args.input_dir, args.pattern)
    logger.info('FILES TO PARSE:\n' + '\n'.join(input_files))

    # one column pair per step; values are wrapped in lists to form a single row
    counts = {}
    col_width = 15
    logger.info(''.join(title.ljust(col_width) for title in ("STEP", "NUM_MOLS", "NUM_ENTRIES", "INPUT_FILE")))
    for input_file in input_files:
        path = Path(input_file)
        # the step is the name of the directory two levels above the file
        step = path.parent.parent.name
        if '.sdf' not in path.suffixes:
            num_rows, num_mols = count_mols(input_file, args.groupby)
        else:
            # in a SD File every entry is counted as one molecule
            num_mols = count_mols_sdf(input_file)
            num_rows = num_mols
        counts.update({
            "pattern": [args.pattern],
            f"{step}_num_mols": [num_mols],
            f"{step}_num_entries": [num_rows],
        })
        logger.info(''.join(f"{cell}".ljust(col_width) for cell in (step, num_mols, num_rows, input_file)))

    # assemble data into a DF of one single row
    df = DataFrame(counts)
    t_parsed = datetime.now()

    # export results: append without header to an existing file, else create it with header
    appending = Path(args.output_csv).exists()
    df.to_csv(
        args.output_csv,
        mode='a' if appending else 'w+',
        sep='|',
        header=not appending,
        index=False,
    )

    # end
    t_exported = datetime.now()
    logger.info("SUMMARY")
    log_field("COMPUTATIONAL TIME: CONFIGURING JOB", t_configured - t_start, pad * 2)
    log_field("COMPUTATIONAL TIME: PARSING INPUT FILES", t_parsed - t_configured, pad * 2)
    log_field("COMPUTATIONAL TIME: EXPORTING RESULTS", t_exported - t_parsed, pad * 2)
    log_field("COMPUTATIONAL TIME: TOTAL", t_exported - t_start, pad * 2)
    logger.info("END")


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MAIN ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


if __name__ == '__main__':
    # run the command-line interface only when executed as a script,
    # then terminate explicitly with exit status 0
    main()
    sys.exit(0)
