#!python

"""
Script annotate_act
==========================
This script is used for annotating Fragment Combinations (FC) with preprocessed
activity data from ChEMBL.
"""

# standard
import warnings
import sys
from datetime import datetime
import logging
import pandas as pd
from pandas import DataFrame
import json
import argparse
from pathlib import Path
# chemoinformatics
import rdkit
# typing
from typing import List
# dev
import npfc
from npfc import load
from npfc import save
from npfc import utils
# disable SettingWithCopyWarning warnings
pd.options.mode.chained_assignment = None  # default='warn'


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FUNCTIONS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


def annotate_fc_with_act(df_fcc: DataFrame,
                         df_act: DataFrame,
                         group_on: List[str] = ['idf1', 'idf2', 'abbrev'],
                         merge_on: str = 'idm') -> DataFrame:
    """Annotate FCs with n_active, n_tot and success_rate columns.

    The FCs are grouped on the group_on columns and each group is merged (how='left') with the activity data
    on the merge_on column. Within a group, n_active and n_tot are replaced by their sums over the group and
    the percentage active/total is stored as 'success_rate', so all FCs of a group share the same three values.
    One DataFrame is created per group and they are finally concatenated (the index is not reset).

    For groups without any activity data (n_tot == 0), the success rate is set to -1.

    Missing values are replaced by 0 only in the n_active and n_tot columns; missing values in any other
    column are left untouched.

    If no group can be built (e.g. the FC DataFrame has no rows), an empty DataFrame with the same columns
    as a regular result is returned instead of raising an error.

    :param df_fcc: a FC DataFrame to annotate
    :param df_act: an activity DataFrame with the merge_on, n_active and n_tot columns
    :param group_on: a list of columns in the FC DataFrame for defining groups of FCs
    :param merge_on: a column for merging the groups of FCs and the activity data
    :return: an annotated FC DataFrame
    """
    act_cols = ['n_active', 'n_tot']
    dfs = []
    # iterate over a copy so the shared default list can never be altered through this call
    for _, g in df_fcc.groupby(list(group_on)):
        df = g.merge(df_act, how='left', on=merge_on)
        # molecules without activity data only lack counts: do not overwrite NaN/None in other columns
        df[act_cols] = df[act_cols].fillna(0)
        n_active = df['n_active'].sum()
        n_tot = df['n_tot'].sum()
        df['n_active'] = n_active
        df['n_tot'] = n_tot

        if n_tot > 0:
            df['success_rate'] = (n_active / n_tot) * 100
        else:
            df['success_rate'] = -1
        dfs.append(df)

    if not dfs:
        # pd.concat raises on an empty list, so build the empty result explicitly
        df = df_fcc.iloc[0:0].merge(df_act, how='left', on=merge_on)
        df['success_rate'] = pd.Series(dtype=float)
        return df

    return pd.concat(dfs)


def main():
    """Run the annotate_act command-line workflow.

    Steps, in order: parse the command-line arguments, check the input/output
    paths with the npfc utils helpers, parse the group_on option into a list of
    column names, log library versions and arguments, load the FC and activity
    files, annotate the FCs with annotate_fc_with_act, save the result and log
    the time spent in each step.

    :raises AttributeError: if the group_on argument is not a string
    :raises ValueError: if group_on does not describe a non-empty list of column names, or if merge_on is None
    """

    # init
    d0 = datetime.now()
    description = """Script is used for annotating Fragments Combinations (FC)
    with preprocessed activity data from the ChEMBL.


    It takes two inputs:
        - a file with the FC to annotate
        - a file with preprocessd activity data with 3 columns:
            - idm: identifier of a molecule
            - n_active: number of active data points found for a molecule
            - n_tot: total number of data points found for a molecule

    It creates one output:
        - the output FC file with activity data.

    When labelling FCs, the data is merged with FCs and then FCs are grouped by idf1, idf2 and abbrev.
    This computes a global n_active and n_tot for all data points involved in each FC.

    Three new columns are thus appended to the original FC file:
        - n_active: the number of data points labelled as active for the corresponding FC
        - n_tot: the total number of data points for the corresponding FC
        - success_rate: the ratio active/total as percentage

    This percentage can then be used as filter to identifying the "best FCs" and the n_tot value
    can be used to add weights to filter out statistically unreliable data.

    It uses the installed npfc libary in your favorite env manager.

    Example:

        >>> annotate_act input_fcc.csv.gz input_act.csv.gz output_fcc.csv.gz

    """

    # parameters CLI
    parser = argparse.ArgumentParser(description=description, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('input_fcc', type=str, default=None, help="FC file to annotate with activity file.")
    parser.add_argument('input_act', type=str, default=None, help="Activity file.")
    parser.add_argument('output_fcc', type=str, default=None, help="Annotated FC file.")
    parser.add_argument('-g', '--group-on', type=str, default="['idf1', 'idf2', 'abbrev']", help='List of columns for grouping FCs.')
    parser.add_argument('-m', '--merge-on', type=str, default="idm", help='Column for merging FCs with activity data. This is a left join because some combinations do not have any activity data.')
    parser.add_argument('--log', type=str, default='INFO', help="Specify level of logging. Possible values are: CRITICAL, ERROR, WARNING, INFO, DEBUG.")
    args = parser.parse_args()

    # logging
    # every message below goes through this logger: module-level logging.info() calls always
    # target the root logger, which is not guaranteed to be the one configured here
    logger = utils._configure_logger(args.log)
    logger.info("RUNNING ACT ANNOTATION")
    warnings.filterwarnings('ignore', category=pd.errors.PerformanceWarning)  # if None is returned instead of a molecule, do not complain about mixed types
    pad = 40

    # parse arguments

    # i/o
    utils.check_arg_input_file(args.input_fcc)
    utils.check_arg_input_file(args.input_act)
    utils.check_arg_output_file(args.output_fcc)
    input_fcc_format, input_fcc_compression = utils.get_file_format(args.input_fcc)
    input_act_format, input_act_compression = utils.get_file_format(args.input_act)
    output_fcc_format, output_fcc_compression = utils.get_file_format(args.output_fcc)
    # group_on
    if not isinstance(args.group_on, str):
        raise AttributeError(f"ERROR! ATTRIBUTE GROUP_ON SHOULD BE OF TYPE STRING BUT RECEIVED '{type(args.group_on)}' INSTEAD ({args.group_on})")
    elif '[' not in args.group_on or ']' not in args.group_on:
        raise ValueError(f"ERROR! ATTRIBUTE GROUP_ON DOES NOT REPRESENT A LIST (NO '[' OR ']' IN {args.group_on})")
    # the list is written with Python-style single quotes on the CLI: convert to JSON before parsing
    try:
        group_on = json.loads(args.group_on.strip().replace("'", '"'))
    except json.JSONDecodeError as err:
        raise ValueError(f"ERROR! ATTRIBUTE GROUP_ON COULD NOT BE PARSED AS A LIST ({args.group_on})") from err
    # brackets alone do not guarantee a list of column names (e.g. '{"a": []}' or '[1, 2]')
    if not isinstance(group_on, list) or not group_on or not all(isinstance(col, str) for col in group_on):
        raise ValueError(f"ERROR! ATTRIBUTE GROUP_ON SHOULD BE A NON-EMPTY LIST OF COLUMN NAMES ({args.group_on})")
    # merge_on
    if args.merge_on is None:
        raise ValueError("ERROR! ATTRIBUTE MERGE_ON CANNOT BE NONE")

    # display infos
    logger.info("LIBRARY VERSIONS:")
    logger.info("rdkit".ljust(pad) + f"{rdkit.__version__}")
    logger.info("pandas".ljust(pad) + f"{pd.__version__}")
    logger.info("npfc".ljust(pad) + f"{npfc.__version__}")
    logger.info("ARGUMENTS:")
    logger.info("INPUT_FCC".ljust(pad) + f"{args.input_fcc}")
    logger.info("INPUT_FCC_FORMAT".ljust(pad) + f"{input_fcc_format}")
    logger.info("INPUT_FCC_COMPRESSION".ljust(pad) + f"{input_fcc_compression}")
    logger.info("INPUT_ACT".ljust(pad) + f"{args.input_act}")
    logger.info("INPUT_ACT_FORMAT".ljust(pad) + f"{input_act_format}")
    logger.info("INPUT_ACT_COMPRESSION".ljust(pad) + f"{input_act_compression}")
    logger.info("OUTPUT_FCC".ljust(pad) + f"{args.output_fcc}")
    logger.info("OUTPUT_FCC_FORMAT".ljust(pad) + f"{output_fcc_format}")
    logger.info("OUTPUT_FCC_COMPRESSION".ljust(pad) + f"{output_fcc_compression}")
    logger.info("LOG".ljust(pad) + f"{args.log}")

    # begin
    logger.info("BEGIN")

    # load input_fcc
    logger.info("LOADING FC FILE")
    d1 = datetime.now()
    df_fcc = load.file(args.input_fcc, decode=True)
    logger.info(f"FOUND {len(df_fcc.index)} FRAGMENT COMBINATIONS")

    # load input_act
    logger.info("LOADING ACTIVITY FILE")
    d2 = datetime.now()
    df_act = load.file(args.input_act, decode=False)
    logger.info(f"FOUND {len(df_act.index)} ACTIVITY DATA POINTS")

    # annotating fragment combinations
    logger.info("ANNOTATING FRAGMENT COMBINATION")
    d3 = datetime.now()
    df_fcc = annotate_fc_with_act(df_fcc, df_act, group_on=group_on, merge_on=args.merge_on)

    # saving results
    logger.info("SAVING RESULTS")
    d4 = datetime.now()
    logger.info(f"SAVING ANNOTATED FCs AT '{args.output_fcc}'")
    save.file(df_fcc, args.output_fcc, encode=True)

    # end: each duration is the difference between two consecutive timestamps taken above
    d5 = datetime.now()
    logger.info("SUMMARY")
    logger.info("COMPUTATIONAL TIME: CONFIGURING JOB".ljust(pad * 2) + f"{d1-d0}")
    logger.info("COMPUTATIONAL TIME: LOADING INPUT FCs".ljust(pad * 2) + f"{d2-d1}")
    logger.info("COMPUTATIONAL TIME: LOADING INPUT ACT".ljust(pad * 2) + f"{d3-d2}")
    logger.info("COMPUTATIONAL TIME: ANNOTATING FCs".ljust(pad * 2) + f"{d4-d3}")
    logger.info("COMPUTATIONAL TIME: SAVING RESULTS".ljust(pad * 2) + f"{d5-d4}")
    logger.info("COMPUTATIONAL TIME: TOTAL".ljust(pad * 2) + f"{d5-d0}")
    logger.info("END")

# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MAIN ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


if __name__ == "__main__":
    # executed as a script: run the CLI workflow, then exit with status 0
    main()
    raise SystemExit(0)
