#!python

"""
Script run_protocol
==========================
This script is used for running the NPFC protocols.
"""

# standard
from pathlib import Path
import json
import warnings
import sys
import subprocess
from datetime import datetime
from math import ceil
import logging
import pandas as pd
import argparse
import pkg_resources
# chemoinformatics
import rdkit
# dev
import npfc
from npfc import utils
from npfc import load


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FUNCTIONS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


def count_required_chunks(input_file, chunksize):
    """Count the molecules of an input file and the number of chunks needed
    to store them, depending on the chunk size.

    :param input_file: path (str) of the file with the molecules to count, possibly gzipped ('.gz' suffix)
    :param chunksize: maximum number of molecules per chunk
    :return: a dict with the keys 'num_mols', 'num_chunks' and 'chunksize'
    """
    gz_suffix = '.gz'

    # count mols; the counting is asked to keep the uncompressed file it works on
    num_mols = load.count_mols(input_file, keep_uncompressed=True)

    # The uncompressed file does not work in the smk file (it makes the whole
    # pipeline overwrite any output), so discard it. The archive will be
    # unzipped again later.
    # Only a file derived from a '.gz' archive is removed: for an input that is
    # not compressed, the "uncompressed" path is the input itself and must
    # never be deleted.
    # NOTE(review): assumes load.count_mols writes the uncompressed file next to
    # the archive, named like the archive without its '.gz' suffix - confirm.
    if input_file.endswith(gz_suffix):
        input_file_uncompressed = input_file[:-len(gz_suffix)]
        Path(input_file_uncompressed).unlink()

    # determine the number of chunks to generate
    num_chunks = estimate_num_chunks(num_mols, chunksize)
    return {'num_mols': num_mols, 'num_chunks': num_chunks, 'chunksize': chunksize}


def estimate_num_chunks(num_mols, chunksize):
    """Return the number of chunks needed to store num_mols molecules when
    each chunk holds at most chunksize molecules (the last chunk may be
    partially filled).
    """
    mols_per_chunk_ratio = num_mols / chunksize
    # a fractional ratio means one extra, incomplete chunk
    return ceil(mols_per_chunk_ratio)


def _exit_on_errors(logger, errors):
    """Log all collected argument errors and exit with status 1.

    Does nothing when the list of errors is empty.
    """
    if not errors:
        return
    logger.error(f"{len(errors)} ERROR(S) FOUND DURING ARGUMENT PARSING:\n" + '\n'.join(errors) + "\n")
    sys.exit(1)


def main():
    """Parse the command line, assemble the configuration and run a NPFC
    protocol with Snakemake.

    The configuration is read from a JSON file (the one given with --config,
    otherwise the default file stored in the npfc package data for the
    protocol) and can be partially overridden by command line arguments.
    Argument errors are collected and reported together before exiting with
    status 1. When the arguments are valid, the rule graph of the protocol is
    drawn as a SVG file in the working directory, then the protocol is run.

    :raises ValueError: if no known protocol keyword is found in the protocol argument
    :raises subprocess.CalledProcessError: if a snakemake command returns a non-zero exit status
    """

    # init
    ds = [('START', datetime.now())]  # (label, timestamp) milestones for the final summary
    description = """This script is used for running a NPFC protocol, using Snakemake locally or in a cluster environment.

    3 protocols are setup by default: 'fragments', 'natural', 'synthetic'.

    These keywords actually refer to hard-coded snakefiles addressing tasks of this project.

    If you would like to run another snakefile, you just have to provide a file. On par with the snakefile, one needs to provide
    a configuration file in JSON format with the expected parameters. Any parameter found in that file will be transferred to the snakefile, which
    is particularly useful for avoiding hard-coded paths, etc. and make the snakefiles more reusable.

    It is possible to override the configuration file parameters by using command line arguments, although only the ones I actually use in this project
    are updated for now. It would be possible to add an argument params with a string of type key1=value1,key2=value2, etc., but it is not implemented (yet).

    This script uses the installed npfc libary in your favorite env manager.

    """
    # parameters CLI
    parser = argparse.ArgumentParser(description=description, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('protocol', type=str, default=None, help="Protocol to apply. If this is not a file, then a lookup for matching the string with known protocols is performed. Possible values are: fragments/natural/synthetic.")
    parser.add_argument('--clusterconfig', type=str, default=None, help="For running on a cluster, specify the cluster configuration file for snakemake.")
    parser.add_argument('-j', '--jobs', type=int, default=6, help="Number of jobs that can be run simultaneously when workflow is run locally.")
    parser.add_argument('--log', type=str, default='INFO', help="Specify level of logging. Possible values are: CRITICAL, ERROR, WARNING, INFO, DEBUG.")
    parser.add_argument('-c', '--config', type=str, default=None, help="JSON configuration file for setting up default variables when the running the protocol. Arguments below overwrite parameters specified in the config file.")
    parser.add_argument('-w', '--wd', type=str, default=None, help="Working directory containing the data folder, where output subfolders and files are stored.")
    parser.add_argument('-p', '--prefix', type=str, default=None, help="Default prefix to use for all output files.")
    parser.add_argument('-m', '--molid', type=str, default=None, help="SDF property to use for setting idm.")
    parser.add_argument('-i', '--input_file', type=str, default=None, help="SDF with molecules to process.")
    parser.add_argument('-f', '--frags_file', type=str, default=None, help="SDF with the fragments to use for substructure seach.")
    parser.add_argument('-n', '--natref', type=str, default=None, help="Root folder of the dataset to use for natural reference when running synthetic protocol.")
    parser.add_argument('--chunksize', type=int, default=None, help="Overwrite config file chunksize property.")
    args = parser.parse_args()

    # logging
    logger = utils._configure_logger(args.log)
    pad = 40
    warnings.filterwarnings('ignore', category=pd.io.pytables.PerformanceWarning)  # if None is returned instead of a molecule, do not complain about mixed types
    logger.info("RUNNING NPFC PROTOCOL")

    # parse arguments, errors are collected so they can all be reported at once
    errors = []
    # protocol
    if args.protocol not in ('fragments', 'natural', 'synthetic', 'commercial') and not Path(args.protocol).exists():
        errors.append(f"ERROR! UNKNOWN PROTOCOL ({args.protocol})!")  # error if file does not exist or if protocol is not recorded as package data

    # config
    if args.config is not None and not Path(args.config).is_file():
        errors.append(f"ERROR! CONFIG FILE FOR PROTOCOL '{args.protocol}' COULD NOT BE FOUND AT '{args.config}'!")

    # display general infos
    logger.info("LIBRARY VERSIONS".center(pad))
    logger.info("rdkit".ljust(pad) + f"{rdkit.__version__}")
    logger.info("pandas".ljust(pad) + f"{pd.__version__}")
    logger.info("npfc".ljust(pad) + f"{npfc.__version__}")
    logger.info("ARGUMENTS".center(pad))
    logger.info("PROTOCOL".ljust(pad) + f"{args.protocol}")
    logger.info("CONFIG".ljust(pad) + f"{args.config}")
    logger.info("CLUSTERCONFIG".ljust(pad) + f"{args.clusterconfig}")
    logger.info("JOBS".ljust(pad) + f"{args.jobs}")

    # default keys set to None
    config = {k: None for k in ('molid', 'prefix', 'WD', 'input_file', 'frags_file', 'chunksize', 'natref')}
    # special case for tautomers
    config['tautomer'] = False

    # begin
    ds.append(('CONFIGURE', datetime.now()))
    logger.info("BEGIN")

    # configure job by updating config with the parsed config file
    if args.config is not None:
        config_file = args.config
        logger.info("PARSING SPECIFIED CONFIG FILE")
    else:
        config_file = pkg_resources.resource_filename('npfc', f"data/{args.protocol}.json")
        logger.info(f"NO SPECIFIED CONFIG FILE, PARSING DEFAULT CONFIG FILE AT '{config_file}'")
        if not Path(config_file).is_file():
            errors.append(f"ERROR! DEFAULT CONFIG FILE FOR PROTOCOL '{args.protocol}' COULD NOT BE FOUND AT '{config_file}'!")
    if not Path(config_file).is_file():
        # all remaining checks read the parsed configuration, so report the
        # errors collected so far instead of crashing when opening the file
        _exit_on_errors(logger, errors)
    with open(config_file, 'r') as config_handle:
        config.update(json.load(config_handle))

    # check arguments

    # input_file
    if args.input_file is not None:
        logger.info(f"OVERWRITING INPUT_FILE TO {args.input_file}")
        config['input_file'] = args.input_file
    if not utils.check_arg_input_file(config['input_file']):
        errors.append(f"ERROR! INPUT_FILE COULD NOT BE FOUND AT '{config['input_file']}'!")

    # chunksize
    if args.chunksize is not None:
        logger.info(f"OVERWRITING CHUNKSIZE TO {args.chunksize}")
        config['chunksize'] = args.chunksize
    if not utils.check_arg_positive_number(config['chunksize']):
        errors.append(f"ERROR! ILLEGAL VALUE FOR  CHUNKSIZE ({config['chunksize']})!")

    # wd
    if args.wd is not None:
        logger.info(f"OVERWRITING WD TO {args.wd}")
        config['WD'] = args.wd
    if not utils.check_arg_input_dir(config['WD']):
        errors.append(f"ERROR! WD COULD NOT BE FOUND AT '{config['WD']}'!")

    # prefix
    if args.prefix is not None:
        logger.info(f"OVERWRITING PREFIX TO {args.prefix}")
        config['prefix'] = args.prefix
    if config['prefix'] is None:
        errors.append("ERROR! PREFIX IS UNDEFINED!")

    # frags_file (optional)
    if args.frags_file is not None:
        logger.info(f"OVERWRITING FRAGS_FILE TO {args.frags_file}")
        config['frags_file'] = args.frags_file
    if config['frags_file'] is not None and not utils.check_arg_input_file(config['frags_file']):
        errors.append(f"ERROR! FRAGS_FILE COULD NOT BE FOUND AT '{config['frags_file']}'!")

    # natref (optional)
    if args.natref is not None:
        logger.info(f"OVERWRITING NATREF TO {args.natref}")
        config['natref'] = args.natref
    if config['natref'] is not None and not utils.check_arg_input_dir(config['natref']):
        errors.append(f"ERROR! natref COULD NOT BE FOUND AT '{config['natref']}'!")

    # env
    if args.clusterconfig is not None and not Path(args.clusterconfig).exists():
        errors.append(f"ERROR! CLUSTERCONFIG WAS SPECIFIED BUT COULD NOT BE FOUND AT '{args.clusterconfig}'!")

    # jobs
    if args.jobs < 1:
        errors.append(f"ERROR! UNSUITABLE VALUE FOR JOBS ({args.jobs})")

    # exit on error(s)
    _exit_on_errors(logger, errors)

    logger.info(f"PROTOCOL - {args.protocol}".center(pad))
    # a protocol is either a snakefile provided by the user or a keyword for a snakefile stored as package data
    if Path(args.protocol).is_file():
        smk_file = args.protocol
    else:
        try:
            smk_file = pkg_resources.resource_filename('npfc', f"data/fc/fc_{args.protocol}.smk")
        except ValueError:   # ### no tested
            logger.error(f"ERROR! PROTOCOL '{args.protocol}' COULD NOT BE FOUND EITHER IN STORED PROTOCOLS OR AT LOCATED FILE!")
            # without a snakefile there is nothing to run
            sys.exit(1)
    # configure current smk command, {rulename} and {jobid} are left as is for snakemake to fill
    smk_command = f"snakemake -k -s {smk_file} --jobname fc_{args.protocol}" + "_{rulename}.{jobid} "

    # determine environment
    if args.clusterconfig is None:
        smk_command += f"-j {args.jobs} "
    else:
        smk_command += '-j 100 --cluster-config ' + args.clusterconfig + ' --cluster "sbatch -A {cluster.account} -p {cluster.partition} -n {cluster.n}  -t {cluster.time} --oversubscribe" '

    # the protocol type is the first known keyword found in the protocol argument
    # (which can be a keyword itself or the path of a snakefile)
    known_protocols = ('fragments', 'natural', 'synthetic', 'commercial', 'other')
    protocol = next((name for name in known_protocols if name in args.protocol), None)
    if protocol is None:
        raise ValueError(f"UNKNOWN PROTOCOL! ('{args.protocol}')")

    # add specific arguments when chunks are needed
    if protocol in ('natural', 'synthetic', 'commercial', 'other'):
        # estimate the number of chunks required based on the number of molecules in the input file
        num_mols_file = f"{config['WD']}/data/00_raw/data/{config['prefix']}_num_mols.json"
        if Path(num_mols_file).exists():
            # reuse the counts cached by a previous run
            logger.info(f"FOUND NUM_MOLS FILE AT '{num_mols_file}'")
            with open(num_mols_file, 'r') as num_mols_handle:
                d = json.load(num_mols_handle)
            for k, v in d.items():
                logger.info(f'OPT {k}: {v}')
            if config['chunksize'] != d['chunksize']:
                logger.warning("THE OPTION CHUNKSIZE IS DIFFERENT IN NUM_MOLS AND JSON CONFIGURATION! USING THE OPTION IN JSON CONFIGURATION INSTEAD.")
                # the configured chunksize wins, so the cached number of chunks has to be recomputed with it
                d['chunksize'] = config['chunksize']
                d['num_chunks'] = estimate_num_chunks(d['num_mols'], config['chunksize'])
        else:
            # count mols for natural and synthetic
            logger.info(f"DID NOT FOUND ANY NUM_MOLS FILE AT '{num_mols_file}', COMPUTING IT...")
            d = count_required_chunks(config['input_file'], config['chunksize'])
            with open(num_mols_file, 'w+') as num_mols_handle:
                json.dump(d, num_mols_handle)
        # add the counts of mols/chunks to the configuration
        config.update(d)
        logger.info("NUMBER OF CHUNKS:".ljust(pad) + f"{config['num_mols']:,d} / {config['chunksize']:,d} = {config['num_chunks']:,d}")

    # add every defined option to the smk command using key='value'
    config_args = []
    for k, v in config.items():
        if v is None:
            continue
        if k == 'protocol_std' and v == 'DEFAULT':
            # only displayed for information, the value transferred to snakemake remains 'DEFAULT'
            if protocol == 'fragments':
                default_protocol_std_file = pkg_resources.resource_filename('npfc', "data/std_fragments.json")
            else:
                default_protocol_std_file = pkg_resources.resource_filename('npfc', "data/std_mols.json")
            logger.info(f"OPT {k}".ljust(pad) + f"{v} ('{default_protocol_std_file}')")
        else:
            logger.info(f"OPT {k}".ljust(pad) + f"{v}")
        config_args.append(f"{k}='{v}'")
    smk_command += ' '.join(["--config", *config_args])

    logger.info(f"SMK COMMAND:\n{smk_command}\n")

    # draw the protocol task tree
    WD = config['WD']
    if WD.endswith('/'):
        WD = WD[:-1]
    svg = WD + f"/{protocol}_{config['prefix']}_tasktree.svg"
    # the command relies on a pipe and a redirection, hence shell=True
    smk_command_svg = smk_command + f" --rulegraph | dot -Tsvg > {svg}"
    logger.info(f"DRAWING TASK TREE AT '{svg}'")
    logger.info(f"COMMAND IS:\n{smk_command_svg}\n")
    res = subprocess.run(smk_command_svg, shell=True, check=True)
    logger.debug(f"{res=}")

    # run protocol
    logger.info("BEGIN OF SNAKEMAKE PROTOCOL")
    logger.info(f"COMMAND IS:\n{smk_command}\n")
    subprocess.run(smk_command, shell=True, check=True)
    ds.append((f"PROTOCOL '{args.protocol}'", datetime.now()))
    logger.info(f"END OF SNAKEMAKE PROTOCOL '{args.protocol}'")

    # end
    ds.append(("END", datetime.now()))
    logger.info("SUMMARY")
    # time spent in each step = difference between two consecutive milestones
    for (_, previous_time), (label, current_time) in zip(ds, ds[1:]):
        if label != 'END':
            logger.info(f"COMPUTATIONAL TIME: {label}".ljust(pad * 2) + f"{current_time - previous_time}")
    logger.info("COMPUTATIONAL TIME: TOTAL".ljust(pad * 2) + f"{ds[-1][1] - ds[0][1]}")
    logger.info("END")


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MAIN ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


if __name__ == '__main__':
    # run the protocol; the next line is only reached when main() did not
    # raise or exit by itself, so a success status is reported
    main()
    raise SystemExit(0)


    # wildcards in jobnames
    # snakemake -k -s /home/gally/Software/anaconda3/envs/npfc3/lib/python3.8/site-packages/npfc/data/fct/fct_commercial.smk
    # -j 100 --cluster-config /home/gally/Software/anaconda3/envs/npfc3/lib/python3.8/site-packages/npfc/data/cluster_clem.json
    # --cluster "sbatch -A {cluster.account} -p {cluster.partition} -n {cluster.n}  -t {cluster.time} --oversubscribe"
    # --config root_dir='fc/02_commercial/zinc'
    # prep_subdir='prep' prefix='zinc'
    # WD='fct/crms_dnp_chembl_zinc/data/02_commercial'
    # config_file='fc/02_commercial/zinc/test_commerical_zinc.json'
    # commercial_ref='fc/02_commercial/zinc/data/prep/04_dedupl/zinc_refs.hdf'
    # color='orange'
    # num_chunks='2'
    # --forceall
    # --jobname "c_{rulename}_{wildcards.cid}.{jobid}"

    # mail
    # echo "test" | mail -s "awesome subject" josemanuel.gally@mpi-dortmund.mpg.de
