#!/usr/bin/env python3
import argparse
import itertools
import math
import os

import matplotlib
import stillsuit
from ligo import segments as ligo_segments

matplotlib.use("agg")
from cycler import cycler
from matplotlib import pyplot

# Shared matplotlib styling applied to every figure this script draws.
matplotlib.rcParams.update(
    {
        # typography: serif text, 12 pt throughout, with a smaller legend
        "font.family": "serif",
        "font.size": 12.0,
        "axes.titlesize": 12.0,
        "axes.labelsize": 12.0,
        "xtick.labelsize": 12.0,
        "ytick.labelsize": 12.0,
        "legend.fontsize": 8.0,
        # figure geometry, output resolution and path rendering
        "figure.figsize": (8, 4),
        "figure.dpi": 300,
        "savefig.dpi": 300,
        "path.simplify": True,
        # colours handed out, in order, to successive plotted series
        "axes.prop_cycle": cycler(
            "color", ["r", "g", "b", "c", "m", "orange", "aqua"]
        ),
    }
)


def process_segments(segments):
    """
    Split per-ifo segments by the exact combination of ifos that were on.

    :param segments: mapping of ifo name -> iterable of ``(start, end)`` pairs
    :returns: ``ligo.segments.segmentlistdict`` keyed by ``frozenset`` of ifo
        names, one key for every non-empty combination of the input ifos
        (largest combinations first; within one size, in sorted-name order).
        Each value holds the times when every ifo in the key was on and no
        other input ifo was, so the values are mutually disjoint.
    """
    seglists = ligo_segments.segmentlistdict(
        {
            ifo: ligo_segments.segmentlist(
                [ligo_segments.segment(*seg) for seg in segments[ifo]]
            )
            for ifo in segments
        }
    )
    # The segmentlist set operations used below assume coalesced (sorted,
    # non-overlapping) lists; coalescing already-clean input changes nothing.
    seglists.coalesce()

    # Sort before taking combinations: frozenset iteration order of strings
    # varies between interpreter runs (hash randomisation), and the order of
    # the returned keys is what the plot's colour cycle follows.
    ifos = sorted(seglists)
    all_ifos = frozenset(ifos)
    combos = [
        frozenset(combo)
        for level in range(len(ifos), 0, -1)
        for combo in itertools.combinations(ifos, level)
    ]

    out = {}
    for combo in combos:
        # on in every ifo of combo, minus anywhere some other ifo was also on
        out[combo] = seglists.intersection(combo) - seglists.union(all_ifos - combo)
    return ligo_segments.segmentlistdict(out)


def process_missed_found(_missed, _found, _segments):
    """
    Bin missed and found injections by on-ifo combination.

    :param _missed: iterable of injection rows, each with a
        ``"geocent_end_time"`` entry
    :param _found: iterable of found records whose injection row is stored
        under ``"simulation"`` (and carries ``"geocent_end_time"``)
    :param _segments: mapping of combination key -> container of times
        supporting ``in`` (e.g. a segmentlist)
    :returns: ``(missed, found)``, two dicts with the same keys as
        ``_segments``, each mapping a key to the list of records whose end
        time lies in that key's segments.  A record whose time falls in
        several keys' segments is listed under each of them.

    FIXME every record is tested against every combination, which is slow.
    """

    def _bin(records, end_time_of):
        # one (initially empty) bin per key of _segments, in the same order
        bins = {combo: [] for combo in _segments}
        for record in records:
            for combo in _segments:
                if end_time_of(record) in _segments[combo]:
                    bins[combo].append(record)
        return bins

    found = _bin(_found, lambda rec: rec["simulation"]["geocent_end_time"])
    missed = _bin(_missed, lambda rec: rec["geocent_end_time"])
    return missed, found


def parse_command_line(argv=None):
    """
    Parse the command line of the missed/found plotting script.

    :param argv: list of arguments to parse, excluding the program name.
        ``None`` (the default) makes argparse read ``sys.argv[1:]``.
    :returns: ``argparse.Namespace`` with ``config_schema``, ``input_db``,
        ``output_name``, ``far_threshold`` and ``verbose`` attributes.
    :raises SystemExit: on invalid arguments, or when ``--output-name`` is
        missing (argparse reports the error and exits).
    """
    parser = argparse.ArgumentParser(
        prog="plot-sim",
        description="This makes a missed found plot",
        epilog="I really hope you enjoy this program.",
    )
    parser.add_argument("-s", "--config-schema", help="config schema yaml file")
    parser.add_argument("--input-db", help="the input database.")
    # The value is handed straight to pyplot.savefig, so without it the run
    # would fail only after all the database work; reject it up front.
    parser.add_argument(
        "--output-name",
        required=True,
        help="the file to write the plot to",
    )
    parser.add_argument(
        "--far-threshold",
        # one false alarm per year (86400 s/day * 365.25 days)
        default=1 / 86400 / 365.25,
        type=float,
        help="FAR threshold in Hz. Default 1/86400/365.25",
    )
    parser.add_argument("-v", "--verbose", help="be verbose", action="store_true")
    return parser.parse_args(argv)


def _decisive_snr(row, combo):
    """
    Return the decisive SNR of one injection row for an ifo combination.

    The decisive SNR is the second-largest of the per-ifo values
    ``row["snr_<ifo>"]`` over the ifos in ``combo``; with a single ifo it is
    that ifo's value.

    :param row: mapping holding per-ifo SNRs under ``"snr_<ifo>"`` keys
    :param combo: non-empty collection of ifo names (a frozenset when it comes
        from process_segments)
    """
    snrs = sorted(row["snr_%s" % ifo] for ifo in combo)
    # A frozenset cannot be indexed (combo[0] raises TypeError), so the
    # single-ifo case reads its value from the sorted list instead.
    return snrs[-2] if len(snrs) > 1 else snrs[-1]


def main():
    """
    Plot decisive SNR versus time for missed and found injections.

    Reads the "afterhtgate" segments and the missed/found injections from
    ``--input-db``, bins the injections by on-ifo combination, and saves a
    semilog plot to ``--output-name``: missed injections as black crosses,
    found injections as coloured circles labelled by their ifo combination.
    """
    args = parse_command_line()

    indb = stillsuit.StillSuit(config=args.config_schema, dbname=args.input_db)

    segmentsdict = process_segments(indb.get_segments(name="afterhtgate"))

    # NOTE(review): selection_func presumably decides which recovered events
    # count as found (FAR at or below the threshold) -- confirm in stillsuit.
    misseddict, founddict = process_missed_found(
        *indb.get_missed_found(
            selection_func=lambda r: r["event"]["far"] <= args.far_threshold
        ),
        segmentsdict,
    )

    fig = pyplot.figure()
    for combo, missed in misseddict.items():
        found = founddict[combo]
        # geocent_end_time is divided by 1e9 for the x axis, which suggests
        # it is stored in nanoseconds -- TODO confirm
        pyplot.semilogy(
            [m["geocent_end_time"] / 1e9 for m in missed],
            [_decisive_snr(m, combo) for m in missed],
            color="black",
            marker="x",
            linestyle="None",
        )
        pyplot.semilogy(
            [f["simulation"]["geocent_end_time"] / 1e9 for f in found],
            [_decisive_snr(f["simulation"], combo) for f in found],
            marker="o",
            # sorted so the label does not depend on frozenset iteration order
            label=",".join(sorted(combo)),
            linestyle="None",
        )
    # Axis decorations are set once. grid() with no argument *toggles* the
    # grid, so the visibility is requested explicitly.
    pyplot.xlabel("Time")
    pyplot.ylabel("Decisive snr")
    pyplot.grid(True)
    pyplot.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
    fig.tight_layout()
    pyplot.savefig(args.output_name)


# Build the plot only when run as a script, not when this module is imported.
if __name__ == "__main__":
    main()
