#!/usr/bin/env python3

"""
Command line script for running raxml-ng for maximum likelihood phylogenetics
and ancestral sequence reconstruction.
"""
__author__ = "Michael J. Harms"
__date__ = "2021-08-12"

import topiary

import sys, argparse, inspect, os, shutil
import multiprocessing as mp

def main(argv):
    """
    Parse command line arguments and run the requested raxml-ng calculation.

    The positional ``mode`` argument selects which topiary.raxml function is
    called. The remaining options are translated into keyword arguments for
    that function and checked against its signature before it is called.

    Parameters
    ----------
    argv : list of str
        command line arguments, excluding the program name (i.e.
        ``sys.argv[1:]``)

    Raises
    ------
    ValueError
        if the mode is not recognized, if an argument was specified that the
        function for that mode does not accept, or if an argument required
        by that function was not specified.
    """

    # Construct main parser
    parser = argparse.ArgumentParser(prog="run-raxml",
                description='Run raxml-ng for ancestral reconstruction etc.')

    # NOTE: none of the arguments below use nargs=1. With nargs=1 argparse
    # stores a one-element list when the option is given on the command line
    # but the bare default otherwise, which previously caused lists to be
    # passed on for --binary and --anc-alt-pp. Without nargs every option is
    # a plain scalar; the command line syntax is unchanged.

    # Add mode positional argument
    parser.add_argument("mode",help="mode. must be ml, model, or anc")

    # Arguments shared by all modes
    parser.add_argument("-c","--csv",help="topiary csv file (.csv)")
    parser.add_argument("-t","--tree",help="tree file (newick)")
    parser.add_argument("-m","--model",help="phylogenetic model")
    parser.add_argument("-o","--output",help="name for outputs")
    parser.add_argument("-T","--threads",help="number of threads. if -1 use all available.",type=int,
                        default=-1)
    parser.add_argument("-b","--binary",help="raxml binary",
                        default=topiary.raxml.RAXML_BINARY)

    # Add anc-mode specific arguments
    parser.add_argument("--anc-alt-pp",
                        help="anc mode: alternate ancestor posterior probability cutoff",
                        type=float,
                        default=0.25)
    parser.add_argument("--anc-support-tree",
                        help="anc mode: tree with supports",
                        type=str)

    parser.add_argument("--force",action='store_true',help="overwrite existing output")

    # Parse arguments
    args = parser.parse_args(argv)

    # Which functions to use for each mode
    modes = {"model":topiary.raxml.find_best_model,
             "ml":topiary.raxml.generate_ml_tree,
             "anc":topiary.raxml.generate_ancestors}

    # Figure out what function to run
    mode = args.mode
    if mode not in modes:
        err = f"\n\nmode {mode} not recognized. should be one of:\n"
        for m in modes:
            err += f"    {m}\n"
        err += "\n\n"
        raise ValueError(err)
    func = modes[mode]

    # Options that are only passed on if the user specified them. Keys are
    # the keyword argument names used when calling the mode function.
    optional = {"tree_file":args.tree,
                "df":args.csv,
                "model":args.model,
                "output":args.output}
    kwargs = {k:v for k, v in optional.items() if v is not None}

    # number of threads (-1 means use every cpu on the machine)
    threads = args.threads
    if threads == -1:
        threads = mp.cpu_count()
    kwargs["threads"] = threads

    # raxml binary
    kwargs["raxml_binary"] = args.binary

    # Ancestor-specific arguments
    if mode == "anc":
        kwargs["alt_cutoff"] = args.anc_alt_pp
        if args.anc_support_tree is not None:
            kwargs["tree_file_with_supports"] = args.anc_support_tree

    # Figure out what kwargs are required and allowed for the function being
    # called. A parameter without a default value is treated as required.
    param = inspect.signature(func).parameters
    required = [p for p in param
                if param[p].default is param[p].empty and p not in kwargs]
    extra = [k for k in kwargs if k not in param]

    # Check for extra arguments
    err = ""
    if extra:
        err += f"\n\nThe following arguments are not compatible with mode {mode}:\n"
        for e in extra:
            err += f"    {e}\n"

    # Check to make sure all required arguments are specified
    if required:
        err += f"\n\nThe following arguments are required for mode {mode}:\n"
        for r in required:
            err += f"    {r}\n"

    # If an error has occurred, report it
    if err != "":
        err += "\n\n"
        raise ValueError(err)

    # Remove existing output if requested. Only do so when an output was
    # actually specified on the command line; there is nothing to look up
    # otherwise.
    if args.force:
        output = kwargs.get("output")
        if output is not None and os.path.exists(output):
            if os.path.isdir(output) and not os.path.islink(output):
                shutil.rmtree(output)
            else:
                # rmtree cannot delete regular files or symlinks
                os.remove(output)

    # Call actual function
    func(**kwargs)

# Entry point when this file is executed as a script rather than imported:
# hand everything after the program name to main().
if __name__ == "__main__":
    cli_args = sys.argv[1:]
    main(cli_args)
