#!/usr/bin/env python3
"""
Multi-parameter likelihood computation module for GW events with a unique EM counterpart

Tathagata Ghosh
"""

import numpy as np
import gwcosmo
import json
import pickle
import bilby
from itertools import product
from gwcosmo.utilities.posterior_utilities import str2bool
from gwcosmo.utilities.arguments import create_parser
from gwcosmo.utilities.cosmology import *
from gwcosmo.utilities.check_boundary import *
from gwcosmo.utilities.host_galaxy_merger_relations import RedshiftEvolutionConstant, RedshiftEvolutionPowerLaw, RedshiftEvolutionMadau
from gwcosmo.injections import injections_at_detector
import astropy.constants as const
from tqdm import tqdm
import h5py

from gwcosmo.prior.priors import *

import matplotlib
# Select the non-interactive 'agg' backend before pyplot is imported below.
matplotlib.use('agg')
import matplotlib.pyplot as plt
# Global font settings applied to every figure drawn by this script.
matplotlib.rcParams['font.family']= 'Times New Roman'
matplotlib.rcParams['font.sans-serif']= ['Bitstream Vera Sans']
matplotlib.rcParams['mathtext.fontset']= 'stixsans'

import seaborn as sns
# Seaborn defaults: 'paper' context, 'ticks' style, 'colorblind' palette.
sns.set_context('paper')
sns.set_style('ticks')
sns.set_palette('colorblind')


# Speed of light in km/s; used below to turn recessional velocities into redshifts.
speed_of_light = const.c.to('km/s').value


# Names of the command-line options handed to create_parser, in the order
# they are registered.
cli_option_names = (
    "--method",
    "--counterpart_dictionary",
    "--counterpart_z",
    "--counterpart_sigmaz",
    "--counterpart_v",
    "--counterpart_sigmav",
    "--counterpart_ra",
    "--counterpart_dec",
    "--posterior_samples",
    "--skymap",
    "--skymap_prior_distance",
    "--skymap_H0",
    "--skymap_Omega_m",
    "--post_los",
    "--redshift_evolution",
    "--parameter_dict",
    "--injections_path",
    "--mass_model",
    "--snr",
    "--plot",
    "--outputfile",
    "--posterior_samples_field",
    "--gravity_model",
    '--sampler',
    "--nwalkers",
    "--npool",
    "--walks",
    "--nsteps",
    "--nlive",
    "--dlogz",
    "--nsamps",
)
parser = create_parser(*cli_option_names)

# Parsed command-line options; every later section reads from this namespace.
opts = parser.parse_args()

# --- Validate the command-line options before any data is loaded ---

# The inference method is mandatory and must be one of the two modes that the
# end of this script knows how to run; anything else would otherwise load all
# the data and then silently do nothing.
if opts.method is None:
    parser.error("Missing method. Choose 'sampling' or 'gridded'.")
method = str(opts.method)
if method not in ("sampling", "gridded"):
    parser.error(f"Unrecognized method '{method}'. Choose 'sampling' or 'gridded'.")

# Without a counterpart dictionary, the counterpart redshift (or velocity) and
# its uncertainty must be given directly on the command line.
if opts.counterpart_dictionary is None :
    if (opts.counterpart_z is None and opts.counterpart_v is None) :
        parser.error("Provide either counterpart redshift or recessional velocity.")
    if (opts.counterpart_sigmaz is None and opts.counterpart_sigmav is None):
        parser.error("Provide either counterpart redshift uncertainty or recessional velocity uncertainty.")

# At least one source of GW data is required.
if opts.posterior_samples is None and opts.skymap is None:
    parser.error ("posterior samples and skymap both are missing")

# Tell the user which input wins when redundant options are supplied.
if (opts.counterpart_z is not None and opts.counterpart_v is not None):
    print("Both counterpart redshift and recessional velocity provided. Using recessional velocity.")
if (opts.counterpart_sigmaz is not None and opts.counterpart_sigmav is not None):
    print("Both counterpart redshift uncertainty and recessional velocity uncertainty provided. Using recessional velocity uncertainty.")
if (opts.posterior_samples is not None and opts.skymap is not None):
    print("Both posterior samples and skymap provided. Using posterior samples.")


# redshift information
# Build counterpart_dictionary so that every event carries a "redshift" entry,
# either read from a JSON file or assembled from the command-line options.
if opts.counterpart_dictionary is not None :
    with open(str(opts.counterpart_dictionary)) as json_file:
        counterpart_dictionary = json.load(json_file)
    for event, counterpart in counterpart_dictionary.items():
        if "redshift" in counterpart :
            continue
        if "velocity" in counterpart :
            # JSON delivers plain Python numbers or lists, and a list cannot be
            # divided by a float; np.asarray handles both a scalar velocity and
            # a sequence.
            # NOTE(review): presumably a [velocity, uncertainty] pair mirroring
            # the [redshift, uncertainty] array built below -- confirm.
            counterpart["redshift"] = np.asarray(counterpart["velocity"], dtype=float)/speed_of_light
        else :
            parser.error (f"provide redshift or velocity for {event}")
else :
    # Redshift inputs are read first so that velocity inputs, when present,
    # override them (matching the message printed during validation).
    if opts.counterpart_z is not None:
        counterpart_z = float(opts.counterpart_z)
    if opts.counterpart_sigmaz is not None:
        counterpart_sigmaz = float(opts.counterpart_sigmaz)
    if opts.counterpart_v is not None:
        counterpart_v = float(opts.counterpart_v)
        counterpart_z = counterpart_v/speed_of_light
    if opts.counterpart_sigmav is not None:
        counterpart_sigmav = float(opts.counterpart_sigmav)
        counterpart_sigmaz = counterpart_sigmav/speed_of_light

    # A single unnamed event is stored under the generic key 'event'.
    counterpart_dictionary = {
        'event': {"redshift": np.array ([counterpart_z,counterpart_sigmaz])},
    }

# ra dec information
# Only needed when a skymap (and no posterior samples) is used and the
# counterpart was described on the command line instead of in a dictionary.
if opts.posterior_samples is None and opts.skymap is not None and opts.counterpart_dictionary is None:
    if opts.counterpart_ra is not None and opts.counterpart_dec is not None:
        counterpart_ra = float(opts.counterpart_ra)
        counterpart_dec = float(opts.counterpart_dec)
        counterpart_dictionary ['event'] ["ra_dec"]= np.array ([counterpart_ra,counterpart_dec])
    else:
        parser.error('Provide ra and/or dec.')

# Skymap prior information
# Report which distance prior is assumed for the skymap; nothing is printed
# for a value that is not one of the three listed here.
if opts.posterior_samples is None and opts.skymap is not None:
    skymap_prior_messages = {
        "dlSquare": "Distance prior for GW skymap is uniform dl^2.",
        "Uniform": "Distance prior for GW skymap is uniform in luminosity distance.",
        "UniformComoving": f"Distance prior for GW skymap is uniform in comoving  volume with H0={opts.skymap_H0} km/s/Mpc and Omega_m={opts.skymap_Omega_m}.",
    }
    skymap_prior_message = skymap_prior_messages.get(opts.skymap_prior_distance)
    if skymap_prior_message is not None:
        print (skymap_prior_message)

# loading GW data (posterior samples/skymap)
if opts.posterior_samples is not None:
    # A '.json' argument is loaded as a dictionary (presumably keyed by event
    # name -- confirm); any other argument is treated as a single samples file
    # stored under the generic key 'event'.
    samples_argument = str(opts.posterior_samples)
    if samples_argument.endswith('.json'):
        with open(samples_argument) as json_file:
            posterior_samples_dictionary = json.load(json_file)
    else :
        posterior_samples_dictionary = {'event': samples_argument}

    # Per-event flag saying whether the samples are already conditioned on the
    # counterpart line of sight, plus an optional per-event sample count.
    post_los_dictionary = {}
    nsamps_dictionary = {}
    nsamps_is_number = isinstance(opts.nsamps, (int, float))
    if opts.post_los is None:
        print ("Assuming posterior samples are conditioned over line of sight.") 
        for event in posterior_samples_dictionary :
            post_los_dictionary [event] = True
    elif str(opts.post_los).endswith('.json'): 
        with open(str(opts.post_los)) as json_file:
            post_los_dictionary_in = json.load(json_file)
        for event, los_entry in post_los_dictionary_in.items() :
            post_los_dictionary [event] = str2bool (los_entry ["post_los"])
            # A per-event "nsamps" takes precedence over the command-line value.
            if "nsamps" in los_entry :
                nsamps_dictionary [event] = int(los_entry ["nsamps"])
            elif nsamps_is_number :
                nsamps_dictionary [event] = int(opts.nsamps)
    else :
        # One flag given on the command line applies to every event.
        post_los = str2bool(opts.post_los)
        for event in posterior_samples_dictionary :
            post_los_dictionary [event] = post_los
            if post_los is False and nsamps_is_number :
                nsamps_dictionary [event] = int(opts.nsamps)

    # Events whose samples are NOT conditioned on the line of sight need the
    # counterpart sky position.
    events = list(post_los_dictionary)
    nevents = len(events)
    if nevents >1 :
        for event in events :
            position_required = post_los_dictionary[event] is False
            if position_required and counterpart_dictionary[event].get("ra_dec") is None :
                parser.error(f"Provide ra and dec infromation in counterpart dictionary for {event}")
    else :
        # With a single event the position may also come from the command line.
        event = events [0]
        if post_los_dictionary[event] is False and counterpart_dictionary[event].get("ra_dec") is None :
            if opts.counterpart_ra is not None and opts.counterpart_dec is not None:
                counterpart_ra = float(opts.counterpart_ra)
                counterpart_dec = float(opts.counterpart_dec)
                counterpart_dictionary [event] ["ra_dec"]= np.array ([counterpart_ra,counterpart_dec])
            else:
                parser.error('Provide ra and/or dec using arguments counterpart_ra and counterpart_dec.')


elif opts.skymap is not None:
    # No posterior samples: fall back to the skymap, again either a JSON
    # dictionary or a single file stored under the key 'event'.
    skymap_argument = str(opts.skymap)
    if skymap_argument.endswith('.json'):
        with open(skymap_argument) as json_file:
            skymap_dictionary = json.load(json_file)
    else :
        skymap_dictionary = {'event': skymap_argument}

# posterior sample field
# Build posterior_samples_field_dictionary with one entry per event in
# posterior_samples_dictionary; None means no field was specified for it.
if opts.posterior_samples is not None:
    if opts.posterior_samples_field is None:
        posterior_samples_field_dictionary = {key: None for key in posterior_samples_dictionary}
    elif str(opts.posterior_samples_field).endswith('.json'):
        with open(str(opts.posterior_samples_field)) as json_file:
            posterior_samples_field_dictionary = json.load(json_file)
        # Events missing from the JSON file get None.
        for key in posterior_samples_dictionary:
            posterior_samples_field_dictionary.setdefault(key, None)
    else:
        # A plain field name applies to every event.  Previously the value was
        # stored under 'event' and then immediately overwritten with None via
        # a stale loop variable, so the option was silently ignored.
        posterior_samples_field_dictionary = {
            key: str(opts.posterior_samples_field) for key in posterior_samples_dictionary
        }

# parameter(s) file
# The parameter dictionary is mandatory.  With a skymap as the only GW input,
# none of the mass parameters may carry a list value (a list marks a parameter
# to be inferred, see the sampling/gridded section).
if opts.parameter_dict is None:
    parser.error('Missing parameters file')
else:
    with open(str(opts.parameter_dict)) as json_file:
        parameter_dict = json.load(json_file)
    if opts.posterior_samples is None and opts.skymap is not None :
        mass_parameters = [
            'alpha', 'delta_m', 'mu_g', 'sigma_g', 'lambda_peak',
            'alpha_1', 'alpha_2', 'b', 'mminbh', 'mmaxbh', 'beta',
            'alphans', 'mminns', 'mmaxns',
        ]
        for key in mass_parameters :
            if key in parameter_dict and isinstance(parameter_dict[key]['value'], list):
                parser.error(f"Mass parameter({key}) cannot be inferred while using GW skymap.")

# Open the injections file read-only; if that raises OSError, retry with HDF5
# file locking disabled.
try :
    injdata = h5py.File(opts.injections_path,'r')
except OSError:
    injdata = h5py.File(opts.injections_path,'r', locking=False)

# The context manager closes the file even if reading a dataset or building
# the injections object fails.
with injdata:
    injections_m1d = np.array(injdata['m1d'])
    injections = injections_at_detector(m1d=injections_m1d,
                                        m2d=np.array(injdata['m2d']),
                                        dl=np.array(injdata['dl']),
                                        prior_vals=np.array(injdata['pini']),
                                        snr_det=np.array(injdata['snr']),
                                        snr_cut=0,
                                        # one +inf entry per injection, paired with ifar_cut=0
                                        ifar=np.inf+0*injections_m1d,
                                        ifar_cut=0,
                                        ntotal=np.array(injdata['ntotal']),
                                        Tobs=np.array(injdata['Tobs']))

# Map each supported --mass_model value onto the prior class that implements it.
mass_model = str(opts.mass_model)
mass_model_classes = {
    'BBH-powerlaw': BBH_powerlaw,
    'NSBH-powerlaw': NSBH_powerlaw,
    'BBH-powerlaw-gaussian': BBH_powerlaw_gaussian,
    'NSBH-powerlaw-gaussian': NSBH_powerlaw_gaussian,
    'BBH-broken-powerlaw': BBH_broken_powerlaw,
    'NSBH-broken-powerlaw': NSBH_broken_powerlaw,
    'BNS': BNS,
}
if mass_model in mass_model_classes:
    mass_priors = mass_model_classes[mass_model]()
else:
    parser.error('Unrecognized mass model')
print(f'Using the {mass_model} mass model')

# Map each supported --gravity_model value onto its cosmology class.
gravity_model = str(opts.gravity_model)
gravity_model_classes = {
    'GR': standard_cosmology,
    'Xi0_n': Xi0_n_cosmology,
}
if gravity_model in gravity_model_classes:
    cosmo = gravity_model_classes[gravity_model]()
else:
    parser.error('Unrecognized gravity model')
print(f'Using the {gravity_model} gravity model')

#check_boundary(cosmo, parameter_dict, injections, mass_priors, gravity_model, mass_model)

# Output options: whether to produce plots, and the base name of output files.
plot = str2bool(opts.plot) 
outputfile = str(opts.outputfile)

#redshift evolution model
# str() turns an unset option (None) into the string 'None', which selects
# the constant model below.
redshift_evolution = str(opts.redshift_evolution)

if redshift_evolution=='PowerLaw':
    ps_z = RedshiftEvolutionPowerLaw()
elif redshift_evolution=='Madau':
    ps_z = RedshiftEvolutionMadau()
elif redshift_evolution=='None':
    ps_z = RedshiftEvolutionConstant()
else:
    # Stop here: otherwise ps_z stays undefined and the likelihood
    # construction further down fails with a NameError.
    parser.error('Unrecognized redshift evolution model')
print(f'Assuming a {redshift_evolution} redshift evolution model')



# Build the multi-event likelihood object.  Posterior samples take precedence
# over a skymap when both were supplied.
likelihood_class = gwcosmo.likelihood.bright_siren_likelihood.MultipleEventLikelihoodEM

if opts.posterior_samples is not None:
    me = likelihood_class(
        counterpart_dictionary,
        injections,
        ps_z,
        cosmo,
        mass_priors,
        posterior_samples_dictionary=posterior_samples_dictionary,
        posterior_samples_field=posterior_samples_field_dictionary,
        network_snr_threshold=opts.snr,
        post_los_dictionary=post_los_dictionary,
        nsamps=nsamps_dictionary,
    )
elif opts.skymap is not None:
    me = likelihood_class(
        counterpart_dictionary,
        injections,
        ps_z,
        cosmo,
        mass_priors,
        skymap_dictionary=skymap_dictionary,
        network_snr_threshold=opts.snr,
        skymap_prior_distance=opts.skymap_prior_distance,
        skymap_H0=float(opts.skymap_H0),
        skymap_Omega_m=float(opts.skymap_Omega_m),
    )
         
# Warn about entries of the parameter file that the likelihood does not list
# among its parameters.
known_parameters = me.parameters.keys()
for key in parameter_dict:
    if key in known_parameters:
        continue
    print(f'WARNING: The parameter {key} from your parameter dictionary is not recognised by the likelihood module.')
        
if method == "sampling" :
    # --- Stochastic sampling with bilby ---

    # Sampler settings from the command line.
    sampler = str(opts.sampler)
    nwalkers = int(opts.nwalkers)
    walks = int(opts.walks)
    nsteps = int(opts.nsteps)
    nlive = int(opts.nlive)
    dlogz = float(opts.dlogz)
    npool = int(opts.npool)

    # A parameter whose 'value' is a list is sampled with the requested prior
    # (the two list entries are passed positionally to the bilby prior); any
    # other value is held fixed as a float.
    priors = {}
    for key, settings in parameter_dict.items():
        value = settings['value']
        if not isinstance(value, list):
            priors[key] = float(value)
            continue
        label = settings.get('label', key)
        prior_type = settings['prior']
        if prior_type in ('Uniform', 'uniform'):
            priors[key] = bilby.core.prior.Uniform(value[0], value[1], name=key, latex_label=label)
        elif prior_type == 'Gaussian':
            priors[key] = bilby.core.prior.Gaussian(value[0], value[1], name=key, latex_label=label)
        else:
            raise ValueError(f"Unrecognised prior settings for parameter {key} (specify 'Uniform' or 'Gaussian').")

    # can be "rwalk", "act-walk" (the default). "acceptance-walk" is faster (as of 20230717),
    # see https://lscsoft.docs.ligo.org/bilby/dynesty-guide.html
    sampling_method = "acceptance-walk"

    bilbyresult = bilby.run_sampler(
        likelihood=me,
        priors=priors,
        outdir=f"{opts.outputfile}",     # the directory to output the results to
        label=f"{opts.outputfile}",      # a label to apply to the output file
        plot=plot,                       # by default this is True and a corner plot will be produced
        sampler=sampler,                 # set the name of the sampler
        # keyword arguments forwarded to the chosen sampler
        nlive=nlive,
        dlogz=dlogz,
        nwalkers=nwalkers,
        walks=walks,
        nsteps=nsteps,
        npool=npool,
        sample=sampling_method,
        #injection_parameters={"m": m, "c": c},  # (optional) true values for adding to plots
    )


elif method == "gridded" :
    # --- Likelihood evaluated on a regular grid ---

    parameter_grid = {}
    fixed_params = {}

    # A list value [low, high] or [low, high, npoints] defines a grid axis
    # (10 points when npoints is omitted); any other value is held fixed.
    for key, settings in parameter_dict.items():
        value = settings['value']
        if not isinstance(value, list):
            fixed_params[key] = float(value)
            continue
        if settings['prior'] not in ('Uniform', 'uniform'):
            raise ValueError(f"Unrecognised prior settings for parameter {key} (specify 'Uniform').")
        # Check the length before indexing so that a malformed entry raises
        # the intended ValueError rather than an IndexError.
        if len(value) == 3:
            # np.linspace requires an integer number of points.
            npoints = int(value[2])
        elif len(value) == 2:
            npoints = 10
        else:
            raise ValueError(f"Incorrect formatting for parameter {key}.")
        print(f"Setting a uniform prior on {key} in the range [{value[0]}, {value[1]}]")
        parameter_grid[key] = np.linspace(value[0], value[1], npoints)

    for key, fixed_value in fixed_params.items():
        me.parameters[key] = fixed_value
        print(f'Setting parameter {key} to {fixed_value}')

    names = list(parameter_grid.keys())
    shape = [len(axis) for axis in parameter_grid.values()]
    likelihood = np.zeros(shape)

    # Every grid point and its index in the likelihood array; both products
    # iterate in the same order.
    values = list(product(*parameter_grid.values()))
    indices = list(product(*(range(axis_length) for axis_length in shape)))

    print(f'Computing the likelihood on a grid over {names}')
    for index, value in tqdm(zip(indices, values), total=len(values)):
        for name, parameter_value in zip(names, value):
            me.parameters[name] = parameter_value
        likelihood[index] = me.log_likelihood()

    # rescale values of the log-likelihood before exponentiating
    likelihood -= np.nanmax(likelihood)
    likelihood = np.exp(likelihood)

    param_values = [parameter_grid[name] for name in names]

    # Saved as one object array: names, grid axes, likelihood grid, the parsed
    # options and the parameter dictionary.
    mylist = np.array([names,param_values,likelihood,opts,parameter_dict],dtype=object)
    np.savez(outputfile+'.npz',mylist)

    if plot:
        no_params = len(names)
        labels = [parameter_dict[name].get("label", name) for name in names]

        if no_params == 1:
            # Single parameter: one curve, normalised using the grid spacing.
            axis_values = param_values[0]
            ind_lik_norm = likelihood/np.sum(likelihood)/(axis_values[1]-axis_values[0])
            plt.figure(figsize=[4.2,4.2])
            plt.plot(axis_values,ind_lik_norm)
            plt.xlabel(labels[0],fontsize=16)
            plt.xlim(axis_values[0],axis_values[-1])
            plt.ylim(0,1.1*np.max(ind_lik_norm))
            plt.ylabel(f'p({labels[0]})',fontsize=16)

        else:
            # Corner-style figure: 1D marginals on the diagonal, 2D marginals
            # below it, and the upper triangle removed.
            fig, ax = plt.subplots(no_params, no_params, figsize=[4*no_params,4*no_params],constrained_layout=True)
            for column in range(no_params):
                for row in range(no_params):
                    if column > row:
                        fig.delaxes(ax[row][column])

                    elif row == column:
                        # Sum over every axis except this parameter's own.
                        summed_axes = tuple(i for i in range(no_params) if i != row)
                        ind_lik = np.sum(likelihood,axis=summed_axes)
                        ind_lik_norm = ind_lik/np.sum(ind_lik)/(param_values[row][1]-param_values[row][0])

                        ax[row,column].plot(param_values[row],ind_lik_norm)
                        ax[row,column].set_xlim(param_values[row][0],param_values[row][-1])
                        ax[row,column].set_ylim(0,1.1*np.max(ind_lik_norm))

                    else:
                        # Sum over every axis except the two being plotted; the
                        # transpose puts the row parameter on the y axis.
                        summed_axes = tuple(i for i in range(no_params) if i not in (row, column))
                        ax[row,column].contourf(param_values[column],param_values[row],np.sum(likelihood,axis=summed_axes).T,20)
                    if column == 0:
                        ax[row,column].set_ylabel(labels[row], fontsize=16)
                    if row == no_params-1:
                        ax[row,column].set_xlabel(labels[column],fontsize=16)

        plt.savefig(f'{outputfile}.png',dpi=100,bbox_inches='tight')


