#!python

"""
Script preprocess_act
==========================
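This script preprocesses a pre-extracted activity file: it filters and labels the
activity data points, maps compound ids to the ids of the kept molecules and counts
the active/total data points per molecule.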
"""

# standard
import warnings
import sys
from datetime import datetime
import logging
import pandas as pd
import argparse
from pathlib import Path
# chemoinformatics
import rdkit
# dev
import npfc
from npfc import load
from npfc import save
from npfc import utils
# disable SettingWithCopyWarning warnings
pd.options.mode.chained_assignment = None  # default='warn'


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FUNCTIONS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


def set_activity_score(activity):
    """Convert a raw activity value into a categorical activity score.

    The value is compared against increasing upper bounds and the score of the
    first bound it is strictly below is returned:

        - below 5:   0
        - below 20:  0.5
        - otherwise: 1

    A value for which both comparisons are false (a NaN for instance) therefore
    falls through to 1.

    :param activity: the activity value to categorize
    :return: 0, 0.5 or 1
    """
    # (exclusive upper bound, score) pairs, checked in ascending order
    for upper_bound, score in ((5, 0), (20, 0.5)):
        if activity < upper_bound:
            return score
    return 1


def main():
    """Run the activity preprocessing command-line workflow.

    Steps, in order:
        1. parse and validate the command-line arguments;
        2. load the tab-separated activity file and keep only the non-toxic rows
           whose ``Plate`` contains ``-10.00-`` or ``-02.00-``;
        3. label every remaining row with ``set_activity_score``;
        4. build the idm_filtered -> idm_kept mapping from the RESULT lines of
           the ``*.log`` files found in ``dir_log_uni``, completed with the idms
           of ``ref_uni`` (which map to themselves);
        5. replace ``Compound_Id`` by the kept idm and compute, for each idm,
           the summed activity scores (``n_active``) and the number of data
           points (``n_tot``);
        6. save the resulting table to ``output_act``.

    :raises ValueError: if ``ref_uni`` or ``dir_log_uni`` is None
    """
    # init
    d0 = datetime.now()
    description = """Script used for preprocessing the raw activity data extracted
    from the ChEMBL database.

    Preprocessing consists in:
        1. labelling activities into active (1), maybe (0.5) or inactive (0) based on rules
        2. updating idms from activity file with kept_idm from ref-uni/dir-log-uni.
           This is so that filtered idms linked with activity are not lost when merging
           activity with unique molecule idms.
        3. count how many active data points are found for each molecule: n_active and n_tot.

    It uses the installed npfc libary in your favorite env manager.

    Example:

        >>> preprocess_act act_raw.csv.gz act.csv.gz

    """
    # parameters CLI
    parser = argparse.ArgumentParser(description=description, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('input_act', type=str, default=None, help="Activity file extracted from the ChEMBL.")
    parser.add_argument('output_act', type=str, default=None, help="Output activity file with the summed activity scores (n_active) and the number of data points (n_tot) per idm.")
    parser.add_argument('ref_uni', type=str, default=None, help="Reference file of unique molecules; its idm column provides the kept idms.")
    parser.add_argument('dir_log_uni', type=str, default=None, help="Directory with the *.log files whose RESULT lines list inchikey|idm_kept|idm_filtered.")
    # NOTE(review): --threshold is parsed but never read in this function;
    # set_activity_score hard-codes its own cut-offs.
    parser.add_argument('--threshold', type=int, default=5, help="Minimum value for pchembl_value for labelling an activity as active.")
    parser.add_argument('--log', type=str, default='INFO', help="Specify level of logging. Possible values are: CRITICAL, ERROR, WARNING, INFO, DEBUG.")
    args = parser.parse_args()

    # logging
    logger = utils._configure_logger(args.log)
    logger.info("RUNNING ACT ANNOTATION")
    # if None is returned instead of a molecule, do not complain about mixed types
    # (pd.errors is the public home of the class re-exported by pd.io.pytables)
    warnings.filterwarnings('ignore', category=pd.errors.PerformanceWarning)
    pad = 40

    # parameters

    # i/o
    utils.check_arg_input_file(args.input_act)
    utils.check_arg_output_file(args.output_act)
    msg_undefined = "ERROR! REF-UNI AND DIR-LOG-UNI SHOULD BE EITHER BOTH DEFINED OR UNDEFINED"
    # ref_uni
    if args.ref_uni is None:
        raise ValueError(msg_undefined)
    utils.check_arg_input_file(args.ref_uni)
    # dir_log_uni
    if args.dir_log_uni is None:
        raise ValueError(msg_undefined)
    utils.check_arg_input_dir(args.dir_log_uni)
    # file formats (only used for display below)
    input_act_format, input_act_compression = utils.get_file_format(args.input_act)
    ref_uni_format, ref_uni_compression = utils.get_file_format(args.ref_uni)
    output_act_format, output_act_compression = utils.get_file_format(args.output_act)

    # display infos
    logger.info("LIBRARY VERSIONS:")
    logger.info("rdkit".ljust(pad) + f"{rdkit.__version__}")
    logger.info("pandas".ljust(pad) + f"{pd.__version__}")
    logger.info("npfc".ljust(pad) + f"{npfc.__version__}")
    logger.info("ARGUMENTS:")
    logger.info("INPUT_ACT".ljust(pad) + f"{args.input_act}")
    logger.info("INPUT_ACT_FORMAT".ljust(pad) + f"{input_act_format}")
    logger.info("INPUT_ACT_COMPRESSION".ljust(pad) + f"{input_act_compression}")
    logger.info("REF_UNI".ljust(pad) + f"{args.ref_uni}")
    logger.info("REF_UNI_FORMAT".ljust(pad) + f"{ref_uni_format}")
    logger.info("REF_UNI_COMPRESSION".ljust(pad) + f"{ref_uni_compression}")
    logger.info("DIR_LOG_UNI".ljust(pad) + f"{args.dir_log_uni}")
    logger.info("OUTPUT_ACT".ljust(pad) + f"{args.output_act}")
    logger.info("OUTPUT_ACT_FORMAT".ljust(pad) + f"{output_act_format}")
    logger.info("OUTPUT_ACT_COMPRESSION".ljust(pad) + f"{output_act_compression}")
    logger.info("LOG".ljust(pad) + f"{args.log}")

    # begin (use the configured logger, like every other message of this function)
    logger.info("BEGIN")

    # load act data
    d1 = datetime.now()
    logger.info("LOADING ACTIVITY FILE")
    df_act = pd.read_csv(args.input_act, sep='\t')
    logger.info(f"LOADED {len(df_act.index)} DATA POINTS")
    if 'Unnamed: 0' in df_act.columns:
        df_act.drop('Unnamed: 0', axis=1, inplace=True)
    logger.debug(f"EXCERPT:\n\n{df_act.head(10)}\n")

    # filter act data
    d2 = datetime.now()
    logger.info("FILTERING ACTIVITY DATA")
    n_act_ini = len(df_act.index)

    # literal substring match: with the default regex=True the dots would match
    # any character; na=False treats a missing plate as "does not contain"
    plate = df_act["Plate"]
    mask_plate = (
        plate.str.contains("-10.00-", regex=False, na=False)
        | plate.str.contains("-02.00-", regex=False, na=False)
    )
    df_act = df_act[mask_plate]
    # element-wise comparison on purpose; copy so that the column assignment
    # below operates on an independent frame and not on a slice
    df_act = df_act[df_act['Toxic'] == False].copy()  # noqa: E712
    n_act_rem = len(df_act.index)
    # an empty input file must not abort the run with a ZeroDivisionError
    ratio_rem = n_act_rem / n_act_ini if n_act_ini else 0.0
    logger.info(f"NUMBER OF REMAINING ENTRIES: {n_act_rem:,}/{n_act_ini:,} ({ratio_rem:.2%})")

    # classify act into active/inactive/maybe
    logger.info("CLASSIFYING ACTIVITY INTO ACTIVE/INACTIVE/MAYBE")
    df_act['active'] = df_act['Activity'].map(set_activity_score)
    n_act_1 = int((df_act['active'] == 1).sum())
    n_act_05 = int((df_act['active'] == 0.5).sum())
    n_act_0 = int((df_act['active'] == 0).sum())
    logger.info(f"NUMBER OF ACTIVES: {n_act_1:,}")
    logger.info(f"NUMBER OF INACTIVES: {n_act_0:,}")
    logger.info(f"NUMBER OF MAYBE: {n_act_05:,}")

    # update act idm with kept idm from ref file
    d3 = datetime.now()
    logger.info("LOADING IDMS OF FILTERED/KEPT ENTRIES")

    # filtered entries
    p_dir_log_uni = Path(args.dir_log_uni)
    files_log_uni = sorted(str(x) for x in p_dir_log_uni.glob("*.log"))
    logger.info(f"FOUND {len(files_log_uni)} LOG FILES")
    logger.debug(f"LOG FILES:\n{files_log_uni}\n")
    df_syn = None
    if files_log_uni:  # pd.concat raises a ValueError on an empty list
        df_log = pd.concat([pd.read_csv(x, sep="@", header=None) for x in files_log_uni])
        # each log contains a list of RESULT lines with inchikey|idm_kept|idm_filtered
        df_log = df_log[df_log[0].str.contains("RESULT", na=False)].copy()
        if len(df_log.index) > 0:
            # keep what follows the last ':' of the line and split it on '|'
            df_log["inchikey"], df_log["idm_kept"], df_log["idm_filtered"] = zip(*df_log[0].map(lambda x: x.split(":")[-1].strip().split("|")))
            df_log.drop(0, axis=1, inplace=True)
            df_log.reset_index(inplace=True, drop=True)
            df_syn = df_log
    if df_syn is None:
        # no log file or no RESULT line: start from an empty mapping
        df_syn = pd.DataFrame({'inchikey': [], 'idm_kept': [], 'idm_filtered': []})
    logger.info(f"NUMBER OF IDM IN LOG FILES: {len(df_syn.index)}")

    # kept entries: every kept idm maps to itself
    df_ref = load.file(args.ref_uni)
    df_ref.rename({"idm": "idm_kept"}, axis=1, inplace=True)
    df_ref["idm_filtered"] = df_ref["idm_kept"]
    logger.info(f"NUMBER OF IDM IN REF_FILE: {len(df_ref.index)}")

    # merge kept/filtered
    df_syn = pd.concat([df_syn, df_ref])
    logger.info(f"TOTAL NUMBER OF IDM: {len(df_syn.index)}")

    # merge act with idms
    logger.info("MERGING IDMS WITH ACTIVITY DATA")
    df_act = df_act.merge(df_syn, how='right', left_on='Compound_Id', right_on='idm_filtered')
    # idms without any activity data point come out of the right merge with a
    # missing Compound_Id and are discarded here
    df_act.dropna(subset=['Compound_Id'], inplace=True)
    df_act.drop(['Compound_Id', 'idm_filtered', 'inchikey'], axis=1, inplace=True)
    df_act.rename({'idm_kept': 'idm'}, axis=1, inplace=True)
    logger.info(f"REMAINING NUMBER OF IDM: {len(df_act.index):,d}")

    # process activity data to count how many data points are active/tot for each molecule
    logger.info("REGROUPING ACTIVITY PER KEPT IDM")
    # string names instead of the builtin sum / a len lambda: same values
    # (sum of the scores, number of rows), without the pandas deprecation warning
    df_act = df_act.groupby('idm').agg({'active': ['sum', 'size']})
    df_act.reset_index(inplace=True)
    df_act.columns = ['idm', 'n_active', 'n_tot']

    # export processed activity data
    d4 = datetime.now()
    logger.info(f"SAVING ANNOTATED FCs AT '{args.output_act}'")
    save.file(df_act, args.output_act)

    # end
    d5 = datetime.now()
    logger.info("SUMMARY")
    # NOTE(review): the labels below do not match the phases actually timed
    # (d2-d1 is the loading of the activity file, d3-d2 its filtering and
    # classification, d4-d3 the idm mapping and regrouping); the strings are
    # left untouched in case the log output is consumed elsewhere.
    logger.info("COMPUTATIONAL TIME: CONFIGURING JOB".ljust(pad * 2) + f"{d1-d0}")
    logger.info("COMPUTATIONAL TIME: LOADING INPUT FCs".ljust(pad * 2) + f"{d2-d1}")
    logger.info("COMPUTATIONAL TIME: LOADING INPUT ACT".ljust(pad * 2) + f"{d3-d2}")
    logger.info("COMPUTATIONAL TIME: ANNOTATING FCs".ljust(pad * 2) + f"{d4-d3}")
    logger.info("COMPUTATIONAL TIME: SAVING RESULTS".ljust(pad * 2) + f"{d5-d4}")
    logger.info("COMPUTATIONAL TIME: TOTAL".ljust(pad * 2) + f"{d5-d0}")
    logger.info("END")


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MAIN ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


if __name__ == '__main__':
    # Script entry point: run the preprocessing workflow, then exit explicitly
    # with status 0. An exception raised by main() propagates before
    # sys.exit(0) is reached.
    main()
    sys.exit(0)
