#!/usr/bin/env python3


# imports
from importlib import import_module
from random import uniform
import argparse
import sys
import os


# auxiliary function to remove package modules from cache, to allow monkey patching cosmological functions
# adapted from https://medium.com/@chipiga86/python-monkey-patching-like-a-boss-87d7ddb8098e
def uncache(exclude):
    """Remove package modules from cache except excluded ones.
    On next import they will be reloaded.

    Every entry of ``sys.modules`` that is the top-level package of one of the
    excluded paths, or a submodule of such a package, is deleted unless it is
    itself listed in ``exclude``. Modules of other packages are left alone.

    Args:
    exclude (iter<str>): Module paths to keep cached, e.g. ``["gwcatalog.cosmology"]``.
        Any iterable works (it is read only once); a single string is taken
        as one module path.
    """
    if isinstance(exclude, str):
        exclude = (exclude,)

    # materialise once: `exclude` is both iterated and membership-tested, which
    # would silently misbehave for a one-shot iterator such as a generator
    keep = frozenset(exclude)
    pkgs = {mod.split('.', 1)[0] for mod in keep}

    # snapshot the names first, since sys.modules is mutated below
    to_uncache = [
        mod
        for mod in list(sys.modules)
        if mod not in keep
        and (mod in pkgs or any(mod.startswith(pkg + '.') for pkg in pkgs))
    ]

    for mod in to_uncache:
        sys.modules.pop(mod, None)


# debug subcommand
def debug(args):
    """Show the underlying redshift distribution and/or observation errors of a detector.

    The detector is the ``debug`` sub-command (LIGO, LISA or ET). With
    ``-d/--distribution`` the function ``gwc.<DETECTOR>_dist`` is called, and with
    ``-e/--error`` the function ``gwc.<DETECTOR>_error`` is called (both, in that
    order, when both flags are set). Each receives ``output=args.output``,
    except that the ``sys.stdout`` default is replaced by ``None`` because
    images are not printed to stdout.

    Raises:
        SystemExit: if no detector was chosen.
    """
    detector = getattr(args, "debug", None)
    # `-d`/`-e` only exist on the detector sub-parsers, so without a detector
    # reading them would raise an AttributeError on the namespace
    if detector not in ("LISA", "ET", "LIGO"):
        raise SystemExit("error: choose a detector for the debug subcommand: LIGO, LISA or ET")

    output = args.output

    # don't print images to stdout
    if output == sys.stdout:
        output = None

    if args.distribution:
        getattr(gwc, f"{detector}_dist")(output=output)
    if args.error:
        getattr(gwc, f"{detector}_error")(output=output)


# plot subcommand
def plot(args):
    """Plot one or more catalogs read from CSV files.

    Every file given with ``-i/--input`` is read with ``gwc.load``. All of them
    are passed to ``gwc.plot`` in a single call, as consecutive
    ``(redshifts, distances, errors, label)`` groups, together with the
    ``-t/--theoretical`` option.

    Labels come from ``-l/--legend`` (one entry per input file, same order).
    When no legend is given, the label is the file name without its directory
    and without a trailing ``.csv`` extension.

    Raises:
        SystemExit: if the legend has fewer entries than there are input files.
    """
    output = args.output
    files = args.input
    # NOTE(review): eval() executes arbitrary Python from the command line; fine for a
    # local CLI run by its own user, but never feed it untrusted text
    legend = eval(args.legend) if args.legend else None
    theoretical = args.theoretical

    # avoid printing the image file
    if output == sys.stdout:
        output = None

    if legend and len(legend) < len(files):
        raise SystemExit(f"error: got {len(legend)} legend entries for {len(files)} input file(s)")

    fargs = ()
    for index, file in enumerate(files):
        redshifts, distances, errors = gwc.load(file)

        if legend:
            label = legend[index]
        else:
            # drop the directory and only a trailing ".csv", not every occurrence of it
            label = os.path.basename(file)
            if label.endswith(".csv"):
                label = label[:-len(".csv")]

        fargs += (redshifts, distances, errors, label)

    gwc.plot(*fargs, theoretical=theoretical, output=output)


# generate subcommand
def generate(args):
    """Generate a gravitational-wave catalog and save it with a commented header.

    The catalog type is the ``generate`` sub-command (GWTC, LIGO, LISA or ET).
    The forecast types read their options (events, years, population,
    redshifts, ideal) from ``args``.

    The header records:
    - the gwcatalog version;
    - for forecasts, the cosmological model (the module-level ``description``
      set by ``main``);
    - how the catalog was requested;
    - for forecasts, the ``--ideal`` flag.

    Everything is written with ``gwc.save`` to ``args.output``.

    Raises:
        SystemExit: if no catalog type was chosen.
    """
    catalog = getattr(args, "generate", None)
    if catalog not in ("GWTC", "LIGO", "LISA", "ET"):
        # without this check every branch is skipped and `ideal`/`redshifts` end up unbound
        raise SystemExit("error: choose a catalog type for the generate subcommand: GWTC, LIGO, LISA or ET")

    output = args.output

    # generic information
    info = f"## generated by gwcatalog (v.{gwc.__version__})\n"

    # cosmological model used by the forecasts (custom or default, chosen in main)
    if catalog != "GWTC":
        info += f"# cosmology: {description}\n"

    # pull down data from the GWTC
    if catalog == "GWTC":
        redshifts, distances, errors = gwc.GWTC()

        info += "# data source: GWTC 1, 2, 2.1 and 3\n# adaptations: propagated redshift error to the luminosity distance, which is then set to be symmetric\n"

    # generate a catalog for LIGO
    elif catalog == "LIGO":
        events = args.events
        ideal = args.ideal
        # the user's list is kept apart from the generated `redshifts` below,
        # because it is the user's list that the header reports
        requested = _parse_redshifts(args.redshifts)

        redshifts, distances, errors = gwc.LIGO(events=events, redshifts=requested, ideal=ideal)

        if events:
            info += f"# observatory: LIGO (forecast)\n# event type: compact binaries\n# number of events: {events}\n"
        elif requested:
            info += f"# observatory: LIGO (forecast)\n# event type: compact binaries\n# redshifts provided by the user: {requested}\n"

    # generate a catalog for LISA
    elif catalog == "LISA":
        population = args.population
        years = args.years
        events = args.events
        requested = _parse_redshifts(args.redshifts)
        ideal = args.ideal

        redshifts, distances, errors = gwc.LISA(population=population, events=events, years=years, redshifts=requested, ideal=ideal)

        if years:
            info += f"# observatory: LISA (forecast)\n# event type: MBHB (population {population})\n# mission lifetime: {years} year(s)\n"
        elif events:
            info += f"# observatory: LISA (forecast)\n# event type: MBHB (population {population})\n# events: {events}\n"
        elif requested:
            info += f"# observatory: LISA (forecast)\n# event type: MBHB (population {population})\n# redshifts provided by the user: {requested}\n"

    # generate a catalog for the ET (the only remaining validated choice)
    else:
        events = args.events
        requested = _parse_redshifts(args.redshifts)
        ideal = args.ideal

        redshifts, distances, errors = gwc.ET(events=events, redshifts=requested, ideal=ideal)

        if events:
            info += f"# observatory: ET (forecast)\n# event type: BNSs\n# number of events: {events}\n"
        elif requested:
            info += f"# observatory: ET (forecast)\n# event type: BNSs\n# redshifts provided by the user: {requested}\n"

    # save information on the usage of the ideal flag
    if catalog != "GWTC":
        info += f"# ideal distribution: {ideal}\n"

    # output the catalog
    gwc.save(redshifts, distances, errors, output, info=info)


def _parse_redshifts(text):
    """Return the list given with -r/--redshifts, or [] when the option was not used."""
    # NOTE(review): eval() runs arbitrary Python taken from the command line. That is
    # acceptable for a local CLI invoked by its own user, but never feed it untrusted text.
    return eval(text) if text else []


# main
def main(args):
    """Configure the cosmology, import gwcatalog and run the chosen sub-command.

    With ``-c/--cosmology`` the user's script is executed, and its ``H`` and
    ``dL`` replace those of ``gwcatalog.cosmology``. The rest of the gwcatalog
    package is then evicted from ``sys.modules``, so that when it is imported
    again below it picks up the patched functions.

    Two globals are set here:
    - ``description``: a one-line name of the model, written into catalog
      headers by ``generate``;
    - ``gwc``: the gwcatalog package, used by every sub-command.
    """
    global description, gwc

    # import custom cosmology, if provided, using monkey patching
    if args.cosmology:
        from gwcatalog import cosmology
        module = _load_custom_cosmology(args.cosmology)
        cosmology.H = module.H
        cosmology.dL = module.dL
        try:
            description = module.description.replace("\n", "")
        except (AttributeError, TypeError):
            # no `description` in the script, or it is not a str
            description = "custom"
        uncache(["gwcatalog.cosmology"])
    else:
        description = "ΛCDM (Ωₘ = 0.284, h = 0.7)"

    # import gwcatalog to global namespace after monkey patching the cosmology
    import gwcatalog as gwc

    # dispatch to the chosen subcommand; nothing runs if none was chosen
    handlers = {"generate": generate, "plot": plot, "debug": debug}
    handler = handlers.get(args.subcommand)
    if handler is not None:
        handler(args)


def _load_custom_cosmology(path):
    """Execute the cosmology script at `path` and return it as a module named ``model``.

    The file (any extension) is read and executed directly. The path never
    goes through a shell, no shared temporary directory is used, and no other
    ``model`` module on ``sys.path`` or in ``sys.modules`` can be picked up
    instead of the user's script.
    """
    import types

    module = types.ModuleType("model")
    module.__file__ = path
    # read as bytes so compile() honours a PEP 263 coding declaration, as import does
    with open(path, "rb") as handle:
        code = compile(handle.read(), path, "exec")
    exec(code, module.__dict__)
    # registered as an import would, so the model's functions resolve by module name
    sys.modules["model"] = module
    return module


# run if called
if __name__ == "__main__":
    # epilog for all parsers
    epilog = "Documentation, bug reports, suggestions and discussions at:\nhttps://github.com/jpmvferreira/gwcatalog"

    # create the top level parser
    parser = argparse.ArgumentParser(epilog=epilog)

    # create global arguments in its own group
    global_group = parser.add_argument_group("Global arguments")
    global_group.add_argument("-c", "--cosmology", help="Provide a different cosmology. Input must be a Python script with the Hubble function H(z) and the luminosity distance dL(z, H).")
    global_group.add_argument("-o", "--output", help="Output the results to the provided file.", default=sys.stdout)

    # create subparser for sub-commands; `.required` is set as an attribute (works on
    # every Python 3) so that a missing choice is a usage error instead of a silent no-op
    subcommands = parser.add_subparsers(title="Available subcommands", dest="subcommand")
    subcommands.required = True

    # create each sub-command as a parser
    generate_parser = subcommands.add_parser("generate", help="Generate catalogs.", epilog=epilog)
    plot_parser = subcommands.add_parser("plot", help="Plot catalogs.", epilog=epilog)
    debug_parser = subcommands.add_parser("debug", help="Show the underlying distributions or errors.", epilog=epilog)

    # sub-command: generate
    generate_subparser = generate_parser.add_subparsers(title="Available catalog types", dest="generate")
    generate_subparser.required = True

    # generate: GWTC
    generate_gwtc = generate_subparser.add_parser("GWTC", help="Generate a GWTC catalog, based on real events, with estimated redshifts using ΛCDM.", epilog=epilog)

    # generate: LIGO
    generate_ligo = generate_subparser.add_parser("LIGO", help="Generate a LIGO forecast catalog with compact binaries.", epilog=epilog)
    generate_ligo_group = generate_ligo.add_argument_group("Keyword arguments")
    generate_ligo_group.add_argument("-e", "--events", type=int, help="Number of events in the catalog.")
    generate_ligo_group.add_argument("-r", "--redshifts", type=str, help="A Python list with the redshift of the events to generate the catalog.", default=[])
    generate_ligo_group.add_argument("-i", "--ideal", action="store_true", help="Generate a catalog such that the events are on top of the theoretical line.")

    # generate: LISA
    generate_lisa = generate_subparser.add_parser("LISA", help="Generate a LISA forecast catalog with MBHBs.", epilog=epilog)
    generate_lisa_group = generate_lisa.add_argument_group("Keyword arguments")
    generate_lisa_group.add_argument("-p", "--population", type=str, help="Specify the MBHB catalog population. Available populations are: No Delay, Delay and Pop III.", required=True)
    generate_lisa_group.add_argument("-y", "--years", type=float, help="Number of years to generate the catalog.", default=0)
    generate_lisa_group.add_argument("-e", "--events", type=int, help="Number of events to generate the catalog.", default=0)
    generate_lisa_group.add_argument("-r", "--redshifts", type=str, help="A Python list with the redshift of the events to generate the catalog.", default=[])
    generate_lisa_group.add_argument("-i", "--ideal", action="store_true", help="Generate a catalog such that the events are on top of the theoretical line.")

    # generate: ET
    generate_et = generate_subparser.add_parser("ET", help="Generate a ET forecast catalog with BNSs.", epilog=epilog)
    generate_et_group = generate_et.add_argument_group("Keyword arguments")
    generate_et_group.add_argument("-e", "--events", type=int, help="Number of events in the catalog.")
    generate_et_group.add_argument("-r", "--redshifts", type=str, help="A Python list with the redshift of the events to generate the catalog.", default=[])
    generate_et_group.add_argument("-i", "--ideal", action="store_true", help="Generate a catalog such that the events are on top of the theoretical line.")

    # sub-command: plot (at least one input file is needed to have something to plot)
    plot_parser_group = plot_parser.add_argument_group("Keyword arguments")
    plot_parser_group.add_argument("-i", "--input", nargs="+", help="Input .csv file(s) that contains the catalog(s) sample(s).", required=True)
    plot_parser_group.add_argument("-l", "--legend", type=str, help="A string with a Python like list with the legend for each parameter, e.g.: \"['\\catalog 1', '\\catalog 2']\". Must match the order of the input files. Defaults to file name.")
    # "\\$" spells the same backslash-dollar text as before without an invalid escape sequence
    plot_parser_group.add_argument("-t", "--theoretical", const=True, nargs="?", help="Show the luminosity distance theoretical line. Optionally provide a label (latex supported if backslash is used to escape special characters, e.g.: \\$ instead of $).")

    # sub-command: debug
    debug_subparser = debug_parser.add_subparsers(title="Available catalog types", dest="debug")
    debug_subparser.required = True

    # debug: LIGO
    debug_ligo = debug_subparser.add_parser("LIGO", help="Show the compact binaries redshift distributions or the LIGO observation errors.", epilog=epilog)
    debug_ligo_group = debug_ligo.add_argument_group("Keyword arguments")
    debug_ligo_group.add_argument("-d", "--distribution", action="store_true", help="Check the underlying compact binaries redshift distributions.")
    debug_ligo_group.add_argument("-e", "--error", action="store_true", help="Check the underlying LIGO observation errors.")

    # debug: LISA
    debug_lisa = debug_subparser.add_parser("LISA", help="Show the MBHBs redshift distributions or the LISA observation errors.", epilog=epilog)
    debug_lisa_group = debug_lisa.add_argument_group("Keyword arguments")
    debug_lisa_group.add_argument("-d", "--distribution", action="store_true", help="Check the underlying MBHB redshift distributions.")
    debug_lisa_group.add_argument("-e", "--error", action="store_true", help="Check the underlying LISA observation errors.")

    # debug: ET
    debug_et = debug_subparser.add_parser("ET", help="Show the BNSs redshift distributions or the ET observation errors.", epilog=epilog)
    debug_et_group = debug_et.add_argument_group("Keyword arguments")
    debug_et_group.add_argument("-d", "--distribution", action="store_true", help="Check the underlying BNS redshift distributions.")
    debug_et_group.add_argument("-e", "--error", action="store_true", help="Check the underlying ET observation errors.")

    # get arguments
    args = parser.parse_args()

    main(args)
