#!python

"""
Script filter_bn
==========================
This is a quick and dirty script to filter out benzene fragment hits from DNP/ChEMBL runs (fs and fcc).

It creates a new folder next to the specified input fragment subdir (e.g. frags_crms, frags_crms_nobn),
filled with relevant data and logs. It does not work with fcg because the presence or absence of a single fragment
can drastically change the outcome.
"""

from pathlib import Path
import re
import argparse
from datetime import datetime
import npfc
from npfc import load
from npfc import save
from npfc import utils
import sys

# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FUNCTIONS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


def fs_filter_fragments(input_file, output_file, idfs=None):
    """Remove every Fragment Hit whose 'idf' is one of the specified idfs.

    The input file is loaded with npfc's load.file (without decoding), the matching
    rows are dropped and the remainder is written with npfc's save.file (without encoding).

    :param input_file: path of the Fragment Hits file to read
    :param output_file: path of the filtered file to write
    :param idfs: the fragment ids to remove; defaults to ['32'] (presumably the benzene fragment id -- confirm against the fragment library)
    :return: a tuple (number of rows kept, number of rows removed, initial number of rows)
    """
    # a None sentinel avoids sharing one mutable list object between all calls
    if idfs is None:
        idfs = ['32']

    df = load.file(input_file, decode=False)
    n_ini = len(df)
    df = df[~df['idf'].isin(idfs)]
    n_kept = len(df)
    save.file(df, output_file, encode=False)

    return (n_kept, n_ini - n_kept, n_ini)


def fc_filter_fragments(input_file, output_file, idfs=None):
    """Remove every Fragment Combination involving one of the specified idfs.

    A combination is dropped as soon as either of its two fragments ('idf1' or 'idf2')
    is listed in idfs. The input file is loaded with npfc's load.file (without decoding)
    and the remainder is written with npfc's save.file (without encoding).

    :param input_file: path of the Fragment Combinations file to read
    :param output_file: path of the filtered file to write
    :param idfs: the fragment ids to remove; defaults to ['32'] (presumably the benzene fragment id -- confirm against the fragment library)
    :return: a tuple (number of rows kept, number of rows removed, initial number of rows)
    """
    # a None sentinel avoids sharing one mutable list object between all calls
    if idfs is None:
        idfs = ['32']

    df = load.file(input_file, decode=False)
    n_ini = len(df)
    # single pass: flag rows where either side of the combination is unwanted
    to_remove = df['idf1'].isin(idfs) | df['idf2'].isin(idfs)
    df = df[~to_remove]
    n_kept = len(df)
    save.file(df, output_file, encode=False)

    return (n_kept, n_ini - n_kept, n_ini)


def fs_edit_log(log_in, log_out, new_count):
    """Copy a fragment search log, flagging it as edited and updating its counts.

    This function is not smart at all and just helps remembering that data files were edited.
    Every line of log_in is stripped of its surrounding whitespace, then:

    - each 'RUNNING' becomes 'RUNNING (EDITED!!!)'
    - each 'FOUND <n> HITS' becomes 'FOUND <new_count> HITS'
    - each 'SAVED <n> RECORDS' becomes 'SAVED <new_count> RECORDS'

    The input is read entirely before the output is opened, so log_in and log_out
    may designate the same file.

    :param log_in: path of the log file to read
    :param log_out: path of the log file to write (overwritten if it already exists)
    :param new_count: the count to report instead of the original ones
    """
    edited_lines = []
    # context managers guarantee the handles are closed even if reading or writing fails
    with open(log_in, "r") as reader:
        for line in reader:
            new_line = line.strip()
            new_line = re.sub(r'RUNNING', "RUNNING (EDITED!!!)", new_line)
            new_line = re.sub(r'FOUND [0-9]+ HITS', f"FOUND {new_count} HITS", new_line)
            new_line = re.sub(r'SAVED [0-9]+ RECORDS', f"SAVED {new_count} RECORDS", new_line)
            edited_lines.append(new_line + "\n")

    with open(log_out, "w") as writer:
        writer.writelines(edited_lines)


def fc_edit_log(log_in, log_out, new_count):
    """Copy a fragment combination log, flagging it as edited and updating its counts.

    This function is not smart at all and just helps remembering that data files were edited.
    Every line of log_in is stripped of its surrounding whitespace, then:

    - each 'RUNNING' becomes 'RUNNING (EDITED!!!)'
    - each 'FOUND <n> COMBINATIONS' becomes 'FOUND <new_count> COMBINATIONS'
    - each 'SAVED <n> RECORDS' becomes 'SAVED <new_count> RECORDS'

    The input is read entirely before the output is opened, so log_in and log_out
    may designate the same file.

    :param log_in: path of the log file to read
    :param log_out: path of the log file to write (overwritten if it already exists)
    :param new_count: the count to report instead of the original ones
    """
    edited_lines = []
    # context managers guarantee the handles are closed even if reading or writing fails
    with open(log_in, "r") as reader:
        for line in reader:
            new_line = line.strip()
            new_line = re.sub(r'RUNNING', "RUNNING (EDITED!!!)", new_line)
            new_line = re.sub(r'FOUND [0-9]+ COMBINATIONS', f"FOUND {new_count} COMBINATIONS", new_line)
            new_line = re.sub(r'SAVED [0-9]+ RECORDS', f"SAVED {new_count} RECORDS", new_line)
            edited_lines.append(new_line + "\n")

    with open(log_out, "w") as writer:
        writer.writelines(edited_lines)


def _find_step_subdir(steps, suffix):
    """Return the first step folder whose name ends with suffix."""
    for step in steps:
        if step.name.endswith(suffix):
            return step
    # same exception type as the bare `[...][0]` lookup would raise, but with an actionable message
    raise IndexError(f"no step subfolder with a name ending with '{suffix}' could be found")


def _ratio(part, whole):
    """Return part / whole, or 0.0 when whole is 0 (an empty step has nothing to compare against)."""
    return part / whole if whole else 0.0


def _filter_step(logger, wd_out, subdir, filter_fragments, edit_log, label):
    """Filter all data chunks of one step, rewrite its logs and report the total counts.

    :param logger: the logger used for reporting progress
    :param wd_out: the output frags subdir
    :param subdir: the input step folder (a Path), expected to contain 'data' and 'log' subfolders
    :param filter_fragments: the function filtering one chunk, returning (passed, filtered, initial)
    :param edit_log: the function rewriting one log with the updated count
    :param label: what is being counted, as displayed in the summary (i.e. 'FRAGMENT HITS')
    """
    logger.info(f"\nPROCESSING: {subdir.name}")
    Path(f"{wd_out}/{subdir.name}/log").mkdir(parents=True, exist_ok=True)

    # data
    logger.info("NOW UPDATING DATA...")
    counts_ini = []
    counts_passed = []
    counts_filtered = []
    for chunk_in in sorted(subdir.glob('data/*')):
        chunk_out = f"{wd_out}/{subdir.name}/data/{chunk_in.name}"
        passed, filtered, ini = filter_fragments(str(chunk_in), chunk_out)
        counts_passed.append(passed)
        counts_filtered.append(filtered)
        counts_ini.append(ini)

    # logs
    logger.info("NOW UPDATING LOGS...")
    logs_in = sorted(subdir.glob('log/*log'))
    # logs and chunks are paired by sorted position only, so a different number of
    # files means some logs are skipped or associated with the count of another chunk
    if len(logs_in) != len(counts_passed):
        logger.warning(f"FOUND {len(logs_in)} LOGS FOR {len(counts_passed)} CHUNKS, UPDATED COUNTS IN LOGS MIGHT BE UNRELIABLE")
    for log_in, new_count in zip(logs_in, counts_passed):
        edit_log(str(log_in), f"{wd_out}/{subdir.name}/log/{log_in.name}", new_count)

    # summary
    passed = sum(counts_passed)
    filtered = sum(counts_filtered)
    initial = sum(counts_ini)
    logger.info(f"TOTAL OF {label} INITIAL: {initial:,}")
    logger.info(f"TOTAL OF {label} REMAINING: {passed:,} ({_ratio(passed, initial):.2%})")
    logger.info(f"TOTAL OF {label} FILTERED: {filtered:,} ({_ratio(filtered, initial):.2%})")


def main():
    """Parse the command line, then filter the fs and fcc steps of the input frags subdir.

    For each of both steps, the filtered data chunks and the edited logs are written
    in a folder of the same name within the output frags subdir.
    """
    # init
    d0 = datetime.now()
    description = """Script used to filter out benzene fragment hits from DNP/ChEMBL runs (fsearch).

    It creates a new folder next to the specified input fragment subdir (i.e. frags_crms, frags_crms_nobn)

    It uses the installed npfc libary in your favorite env manager.

    Example:

        >>> filter_bn frags_crms frags_crms_nobn

    """

    # parameters CLI
    parser = argparse.ArgumentParser(description=description, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('input_frags_subdir', type=str, default=None, help="Input frags subidr")
    parser.add_argument('output_frags_subdir', type=str, default=None, help="Output frags subidr.")
    parser.add_argument('--log', type=str, default='INFO', help="Specify level of logging. Possible values are: CRITICAL, ERROR, WARNING, INFO, DEBUG.")
    args = parser.parse_args()

    # logging
    logger = utils._configure_logger(args.log)
    logger.info("RUNNING FILTER_BN")
    pad = 40

    # parse arguments
    utils.check_arg_input_dir(args.input_frags_subdir)
    utils.check_arg_output_dir(args.output_frags_subdir)
    wd_in = args.input_frags_subdir  # the frags_xxxx subfolder in the relevant dataset folder tree
    wd_out = args.output_frags_subdir

    # get steps
    steps = sorted([d for d in Path(wd_in).glob('*') if d.is_dir()])
    subdir_fs = _find_step_subdir(steps, '_fs')
    subdir_fcc = _find_step_subdir(steps, '_fcc')

    # display infos
    logger.info("LIBRARY VERSIONS:")
    logger.info("npfc".ljust(pad) + f"{npfc.__version__}")
    logger.info("ARGUMENTS:")
    logger.info('INPUT FRAGS SUBDIR'.ljust(pad) + f"{wd_in}")
    logger.info('OUTPUT FRAGS SUBDIR'.ljust(pad) + f"{wd_out}")
    logger.info('STEP FS'.ljust(pad) + f"{subdir_fs}")
    logger.info('STEP FCC'.ljust(pad) + f"{subdir_fcc}")
    logger.info("LOG".ljust(pad) + f"{args.log}")

    # begin
    logger.info("RUNNING FILTER BN")
    _filter_step(logger, wd_out, subdir_fs, fs_filter_fragments, fs_edit_log, "FRAGMENT HITS")
    _filter_step(logger, wd_out, subdir_fcc, fc_filter_fragments, fc_edit_log, "FRAGMENT COMBINATIONS")

    d1 = datetime.now()
    logger.info("COMPUTATIONAL TIME: TOTAL".ljust(pad * 2) + f"{d1-d0}")
    logger.info("END")


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MAIN ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


if __name__ == "__main__":
    # run the script, then terminate explicitly with a success exit status
    main()
    raise SystemExit(0)
