#!/bin/env python

import argparse
import distutils
import json
import os
from fractions import Fraction
from pathlib import Path

import bilby
import h5py
import numpy as np
from loguru import logger

import pygwb.pe as pe
from pygwb.baseline import Baseline
from pygwb.detector import Interferometer


def load_npz(fname_path):
    """
    Load the spectra, frequencies and frequency mask stored in an npz file.

    Bins where the point estimate is NaN or zero, or where sigma is zero,
    are removed from the returned point estimate, sigma and frequencies.

    Parameters
    ----------
    fname_path: str or path-like
        Path to an npz file containing ``point_estimate_spectrum`` and
        ``sigma_spectrum``, and optionally ``frequencies`` and
        ``frequency_mask``.

    Returns
    -------
    tuple
        ``(freqs, Y_f, sigma_f, freq_mask)``. ``freqs`` is ``None`` when the
        file has no ``frequencies`` entry; ``freq_mask`` is ``True`` when the
        file has no ``frequency_mask`` entry.
    """
    # NpzFile keeps the archive open until closed, so read everything inside
    # the context manager.
    with np.load(fname_path) as data:
        ptEst_ff = np.real(data["point_estimate_spectrum"])
        sigma_ff = data["sigma_spectrum"]
        try:
            freq_mask = data["frequency_mask"]
        except KeyError:
            logger.info("Data file has no notches saved and no notchfile was passed; no notching will be applied.")
            freq_mask = True
        freqs = data["frequencies"] if "frequencies" in data.files else None

    # np.logical_and only takes two operands (a third positional argument is
    # the ``out`` buffer), so the three conditions are chained explicitly.
    goodInds = ~np.isnan(ptEst_ff) & (ptEst_ff != 0) & (sigma_ff != 0)

    if freqs is not None:
        freqs = freqs[goodInds]
    Y_f = ptEst_ff[goodInds]
    sigma_f = sigma_ff[goodInds]
    # NOTE(review): freq_mask is returned unfiltered, so it can be longer than
    # the filtered spectra when bins are dropped -- confirm callers expect this.
    return (freqs, Y_f, sigma_f, freq_mask)

def load_hdf5(fname_path):
    """
    Load the spectra, frequencies and frequency mask stored in an HDF5 file.

    Unlike ``load_npz``, no bins are filtered out of the returned arrays.

    Parameters
    ----------
    fname_path: str or path-like
        Path to an HDF5 file containing the datasets
        ``point_estimate_spectrum`` and ``sigma_spectrum``, and optionally
        ``freqs`` and ``freqs_mask``.

    Returns
    -------
    tuple
        ``(freqs, point_estimate_spectrum, sigma_spectrum, freq_mask)``.
        ``freqs`` is ``None`` when the file has no ``freqs`` dataset;
        ``freq_mask`` is ``True`` when the file has no ``freqs_mask`` dataset.
    """
    # np.array(...) copies each dataset into memory, so the file can be closed
    # as soon as the context manager exits.
    with h5py.File(fname_path, "r") as hf:
        point_estimate_spectrum = np.real(np.array(hf.get("point_estimate_spectrum")))
        sigma_spectrum = np.array(hf.get("sigma_spectrum"))

        if "freqs_mask" in hf:
            freq_mask = np.array(hf.get("freqs_mask"))
        else:
            # hf.get would return None here and np.array(None) is not a usable
            # mask; fall back to True, the same value load_npz uses.
            logger.info("Data file has no notches saved and no notchfile was passed; no notching will be applied.")
            freq_mask = True

        freqs = np.array(hf.get("freqs")) if "freqs" in hf else None

    return (freqs, point_estimate_spectrum, sigma_spectrum, freq_mask)
    

def _str_to_bool(value):
    """Convert a command-line truth string (yes/no, true/false, on/off, 1/0, ...) to a bool.

    Accepts the same values as the former ``distutils.util.strtobool``, which
    is no longer available from Python 3.12 on. Raises ValueError otherwise.
    """
    normalized = value.strip().lower()
    if normalized in ("y", "yes", "t", "true", "on", "1"):
        return True
    if normalized in ("n", "no", "f", "false", "off", "0"):
        return False
    raise ValueError(f"invalid truth value {value!r}")


def _corner_plot_ranges(priors):
    """Return one (minimum, maximum) plot range per prior, replacing infinite bounds by -4 / 4."""
    ranges = []
    for prior in priors.values():
        lower = -4 if prior.minimum == -np.inf else prior.minimum
        upper = 4 if prior.maximum == np.inf else prior.maximum
        ranges.append((lower, upper))
    return ranges


def main():
    """
    Command-line entry point: run parameter estimation on a spectrum file or pickled baseline.

    The script parses the command-line arguments, validates the model
    parameters against the selected model, builds (or loads) a baseline,
    optionally sets its frequency mask, runs the bilby ``dynesty`` sampler
    and finally produces a corner plot of the result.

    Raises
    ------
    ValueError
        If the model is not supported, a required model parameter is missing,
        or ``--apply_notching`` is not a recognised truth value.
    AttributeError
        If a supplied parameter does not belong to the selected model.
    TypeError
        If the input file extension is not supported.
    """
    pe_parser = argparse.ArgumentParser(add_help=True)
    pe_parser.add_argument(
        "--input_file", help="Path to data file (or pickled baseline) to use for analysis.", action="store", type=str, required=True
    )
    pe_parser.add_argument(
        "--ifos", help="List of names of two interferometers for which you want to run PE, default is H1 and L1. Not needed when running from a pickled baseline.", action="store", type=str, required=False, default = "None"
    )
    pe_parser.add_argument(
        "--model", help="Model for which to run the PE. \n Default is \"Power-Law\". ", action="store", type=str, required=False, default = "Power-Law"
    )
    pe_parser.add_argument(
        "--model_prior_file", help="Points to a file with the required parameters for the model and their priors. \n Default value is power-law parameters omega_ref with LogUniform bilby prior from 1e-11 to 1e-8 and alpha with Uniform bilby prior from -4 to 4.", type=str, action="store", required=False, default = "None"
    )
    pe_parser.add_argument(
        "--non_prior_args", help="A dictionary with the parameters of the model that are not associated with a prior, such as the reference frequency for Power-Law. Default value is reference frequency at 25 Hz for the power-law model.", type=str, action="store", required=False, default="None"
    )
    pe_parser.add_argument(
        "--outdir", help="Output directory of the PE (sampler). Default is \"./pe_output\".", type=str, action="store", required=False, default='./pe_output'
    )
    pe_parser.add_argument(
        "--apply_notching", help="Apply notching to the data. Default is True.", action="store", type=str, required=False
    )
    pe_parser.add_argument(
        "--notch_list_path", help="Absolute path from where you can load in notch list file.", action="store", type=str, required=False
    )
    pe_parser.add_argument(
        "--injection_parameters", help="The injected parameters.", required= False, default="None", type=str, action="store"
    )
    pe_parser.add_argument(
        "--quantiles", help="The quantiles used for plotting in plot corner. Default is [0.05, 0.95].", type=str, action="store", required=False, default="None"
    )
    pe_parser.add_argument(
        "--calibration_epsilon", help="Calibration uncertainty. Default is 0.", type=float, action="store", default=None, required=False,
    )

    # Unknown arguments are ignored rather than rejected.
    script_args = pe_parser.parse_known_args()[0]

    dictionary_models = {"Power-Law": pe.PowerLawModel,
                         "Broken-Power-Law": pe.BrokenPowerLawModel,
                         "Triple-Broken-Power-Law": pe.TripleBrokenPowerLawModel,
                         "Smooth-Broken-Power-Law": pe.SmoothBrokenPowerLawModel,
                         "Schumann": pe.SchumannModel,
                         "Geodesy": pe.Geodesy,
                         "Parity-Violation": pe.PVPowerLawModel,
                         "Parity-Violation-2": pe.PVPowerLawModel2} #"TVS-Power-Law": pe.TVSPowerLawModel

    dictionary_parameters = {"Power-Law": ["omega_ref","alpha", "fref"],
                             "Broken-Power-Law": ["omega_ref", "fbreak", "alpha_1", "alpha_2"],
                             "Triple-Broken-Power-Law": ["omega_ref", "alpha_1", "alpha_2", "alpha_3", "fbreak1", "fbreak2"],
                             "Smooth-Broken-Power-Law": ["omega_ref", "fbreak", "alpha_1", "alpha_2", "delta"],
                             "Schumann": ["kappa_H1", "kappa_L1", "beta_L1", "beta_H1"],
                             "Geodesy": ["omega_ref", "alpha", "beta", "omega_det1", "omega_det2"],
                             "Parity-Violation": ["omega_ref", "alpha", "Pi"],
                             "Parity-Violation-2": ["omega_ref", "alpha", "beta"]}

    if script_args.model not in dictionary_models:
        raise ValueError(
            f"The model {script_args.model} you provided is not supported."
        )
    class_model = dictionary_models[script_args.model]
    required_parameters = dictionary_parameters[script_args.model]

    # Notching is on unless the user explicitly passes a false-like string.
    if script_args.apply_notching:
        pipe_apply_notching = _str_to_bool(script_args.apply_notching)
    else:
        pipe_apply_notching = True

    if script_args.model_prior_file == 'None':
        dict_loaded = {"omega_ref":bilby.core.prior.LogUniform(1e-11,1e-8,"$\\Omega_{ref}$"),"alpha":bilby.core.prior.Uniform(-4, 4,'$\\alpha$')}
    else:
        with open(script_args.model_prior_file) as prior_file:
            dict_loaded = json.load(prior_file, object_hook=bilby.core.utils.decode_bilby_json)

    if script_args.non_prior_args == 'None':
        extra_kwargs = {"fref": 25}
    else:
        extra_kwargs = json.loads(script_args.non_prior_args)

    # Create the output directory up front, whatever path was requested.
    os.makedirs(script_args.outdir, exist_ok=True)

    all_parameters_dict = dict_loaded.copy()
    all_parameters_dict.update(extra_kwargs)

    for key in all_parameters_dict:
        if key not in required_parameters:
            raise AttributeError(
                f"The parameter {key} you provided is not associated with the given model. \n These are the required parameters: {dictionary_parameters[script_args.model]}."
            )

    if not all(key in all_parameters_dict for key in required_parameters):
        raise ValueError(
            f"Not all required parameters of the model are present in one of the parameters dictionaries. \n These are the required parameters: {dictionary_parameters[script_args.model]}."
        )

    # Only spectrum files provide a mask; a pickled baseline is left as loaded.
    freq_mask = None
    input_file = script_args.input_file

    if input_file.endswith(("npz", ".h5", ".hdf5")):

        if input_file.endswith("npz"):
            frequencies, point_estimate_spectrum, sigma_spectrum, freq_mask = load_npz(input_file)
        else:
            frequencies, point_estimate_spectrum, sigma_spectrum, freq_mask = load_hdf5(input_file)

        if script_args.ifos == "None":
            ifos = ['H1', 'L1']
        else:
            ifos = json.loads(script_args.ifos)

        ifo_1 = Interferometer.get_empty_interferometer(ifos[0])
        ifo_2 = Interferometer.get_empty_interferometer(ifos[1])

        base_12 = Baseline.from_interferometers([ifo_1, ifo_2])

        if script_args.calibration_epsilon is not None:
            base_12.calibration_epsilon = script_args.calibration_epsilon

        if not base_12._orf_polarization_set:
            base_12.orf_polarization = "tensor"

        base_12.frequencies = frequencies

        base_12.point_estimate_spectrum = point_estimate_spectrum
        base_12.sigma_spectrum = sigma_spectrum

    elif input_file.endswith((".p", ".pickle")):

        base_12 = Baseline.load_from_pickle(input_file)
        base_12.point_estimate_spectrum = np.real(base_12.point_estimate_spectrum)

    else:
        raise TypeError(
            "The provided file format is currently not supported, try an npz, hdf5 file or a pickled baseline instead."
        )

    if pipe_apply_notching:
        if script_args.notch_list_path:
            base_12.set_frequency_mask(notch_list_path = Path(script_args.notch_list_path))
        elif freq_mask is not None:
            base_12.frequency_mask = freq_mask
        else:
            # Pickled baseline without a notch list: there is no mask to assign.
            logger.info("No notch list was passed for the pickled baseline; its frequency mask is left unchanged.")

    kwargs = {"baselines":[base_12], "model_name":script_args.model}
    kwargs.update(extra_kwargs)
    model = class_model(**kwargs)
    priors = dict_loaded

    if script_args.injection_parameters == "None":
        injection_parameters = dict(omega_ref=10**(-8.67985), alpha=2/3)
    else:
        injection_parameters = json.loads(script_args.injection_parameters)
        # Fraction lets values be given as strings such as "2/3".
        for par in injection_parameters:
            injection_parameters[par] = float(Fraction(injection_parameters[par]))

    range_list = _corner_plot_ranges(priors)

    hlv=bilby.run_sampler(likelihood=model,priors=priors,sampler='dynesty', npoints=1000, walks=10, maxmcmc=10000,
    outdir=script_args.outdir,label= 'hlv', injection_parameters = injection_parameters)

    if script_args.quantiles == "None":
        quantiles = [0.05, 0.95]
    else:
        quantiles = json.loads(script_args.quantiles)

    hlv.plot_corner(show_titles=True, title_fmt='.3g', use_math_text=True, quantiles=quantiles, range=range_list)


# Run the parameter-estimation pipeline only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
