#!/usr/bin/env python3
"""
Dark siren likelihood computation script (supports the 'sampling' and 'gridded' methods)
Rachel Gray
"""

import gwcosmo
import numpy as np
import json
import pickle
import bilby
import h5py
from itertools import product
from tqdm import tqdm

from gwcosmo.utilities.posterior_utilities import str2bool
from gwcosmo.utilities.arguments import create_parser
from gwcosmo.utilities.cosmology import *
from gwcosmo.utilities.check_boundary import *
from gwcosmo.injections import injections_at_detector

from gwcosmo.prior.priors import *

import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt
matplotlib.rcParams['font.family']= 'Times New Roman'
matplotlib.rcParams['font.sans-serif']= ['Bitstream Vera Sans']
matplotlib.rcParams['mathtext.fontset']= 'stixsans'

import seaborn as sns
sns.set_context('paper')
sns.set_style('ticks')
sns.set_palette('colorblind')


def _load_event_dictionary(path):
    """Turn a command-line file argument into a dictionary.

    If ``path`` ends in '.json' the file is parsed with ``json.load`` and the
    parsed content is returned as-is. Otherwise the path itself is returned
    under the single key 'event'.
    """
    path = str(path)
    if path.endswith('.json'):
        with open(path) as json_file:
            return json.load(json_file)
    return {'event': path}


# Command-line options understood by this script.
parser = create_parser(
    "--method", "--posterior_samples", "--posterior_samples_field", "--skymap",
    "--redshift_evolution", "--reweight_posterior_samples", "--sky_area",
    "--min_pixels", "--outputfile", "--LOS_catalog", "--parameter_dict",
    "--plot", "--injections_path", "--mass_model", "--snr", "--gravity_model",
    '--sampler', "--nwalkers", "--npool", "--walks", "--nsteps", "--nlive",
    "--dlogz",
)

opts = parser.parse_args()

print(opts)

# parser.error() prints the message and exits, so each check below acts as a guard.
if opts.method is None:
    parser.error("Missing method. Choose 'sampling' or 'gridded'.")
method = str(opts.method)
# Reject unsupported methods up front: otherwise all the expensive setup runs
# and the script then finishes without computing anything.
if method not in ('sampling', 'gridded'):
    parser.error(f"Unrecognized method '{method}'. Choose 'sampling' or 'gridded'.")

if opts.LOS_catalog is None:
    parser.error('Missing LOS_catalog')
LOS_catalog_path = str(opts.LOS_catalog)

if opts.posterior_samples is None:
    parser.error('Missing posterior samples')
posterior_samples_dictionary = _load_event_dictionary(opts.posterior_samples)

if opts.skymap is None:
    parser.error('Missing skymap')
skymaps_dictionary = _load_event_dictionary(opts.skymap)

if opts.parameter_dict is None:
    parser.error('Missing parameters file')
# The parameter file is always read as JSON, whatever its extension.
with open(str(opts.parameter_dict)) as json_file:
    parameter_dict = json.load(json_file)

# Build the per-event posterior-samples field dictionary.
#   * option not given  -> every event in posterior_samples_dictionary maps to None
#   * '.json' file      -> loaded from file; events missing from it are added as None
#   * any other string  -> used directly under the single key 'event'
field_arg = opts.posterior_samples_field
if field_arg is None:
    posterior_samples_field_dictionary = dict.fromkeys(posterior_samples_dictionary)
elif str(field_arg).endswith('.json'):
    with open(str(field_arg)) as json_file:
        posterior_samples_field_dictionary = json.load(json_file)
    for key in posterior_samples_dictionary:
        posterior_samples_field_dictionary.setdefault(key, None)
else:
    posterior_samples_field_dictionary = {'event': str(field_arg)}

# Read the injection arrays from the HDF5 file given by --injections_path and
# wrap them in an injections_at_detector object.
if opts.injections_path is None:
    parser.error('Missing injections_path')

try:
    injdata = h5py.File(opts.injections_path, 'r')
except OSError:
    # Opening failed: retry once with HDF5 file locking disabled.
    injdata = h5py.File(opts.injections_path, 'r', locking=False)

# The context manager closes the file even if reading a dataset or building
# the injections object raises.
with injdata:
    m1d = np.array(injdata['m1d'])
    injections = injections_at_detector(m1d=m1d,
                                        m2d=np.array(injdata['m2d']),
                                        dl=np.array(injdata['dl']),
                                        prior_vals=np.array(injdata['pini']),
                                        snr_det=np.array(injdata['snr']),
                                        snr_cut=0,
                                        # one +inf IFAR entry per injection
                                        # NOTE(review): with ifar_cut=0 and snr_cut=0 this presumably
                                        # disables the cuts at this stage -- confirm
                                        ifar=np.inf+0*m1d,
                                        ifar_cut=0,
                                        ntotal=np.array(injdata['ntotal']),
                                        Tobs=np.array(injdata['Tobs']))

# Map each supported --mass_model name to the mass prior class that implements it.
mass_model_classes = {
    'BBH-powerlaw': BBH_powerlaw,
    'NSBH-powerlaw': NSBH_powerlaw,
    'BBH-powerlaw-gaussian': BBH_powerlaw_gaussian,
    'NSBH-powerlaw-gaussian': NSBH_powerlaw_gaussian,
    'BBH-broken-powerlaw': BBH_broken_powerlaw,
    'NSBH-broken-powerlaw': NSBH_broken_powerlaw,
    'BNS': BNS,
}

mass_model = str(opts.mass_model)
if mass_model not in mass_model_classes:
    # parser.error() prints the message and exits
    parser.error('Unrecognized mass model')
mass_priors = mass_model_classes[mass_model]()
print(f'Using the {mass_model} mass model')

# Map each supported --gravity_model name to its cosmology class.
gravity_model_classes = {
    'GR': standard_cosmology,
    'Xi0_n': Xi0_n_cosmology,
}

gravity_model = str(opts.gravity_model)
if gravity_model not in gravity_model_classes:
    # parser.error() prints the message and exits
    parser.error('Unrecognized gravity model')
cosmo = gravity_model_classes[gravity_model]()
print(f'Using the {gravity_model} gravity model')

#check_boundary(cosmo, parameter_dict, injections, mass_priors, gravity_model, mass_model)

# General run settings taken from the command line.
plot = str2bool(opts.plot)
reweight_posterior_samples = str2bool(opts.reweight_posterior_samples)
outputfile = str(opts.outputfile)
redshift_evolution = str(opts.redshift_evolution)

# Select the redshift evolution model. The string 'None' (not the None
# object) selects the constant model.
if redshift_evolution=='PowerLaw':
    ps_z = gwcosmo.utilities.host_galaxy_merger_relations.RedshiftEvolutionPowerLaw()
elif redshift_evolution=='Madau':
    ps_z = gwcosmo.utilities.host_galaxy_merger_relations.RedshiftEvolutionMadau()
elif redshift_evolution=='None':
    ps_z = gwcosmo.utilities.host_galaxy_merger_relations.RedshiftEvolutionConstant()
else:
    # Without this branch ps_z would stay undefined and the script would only
    # fail later with a NameError when the likelihood is constructed.
    parser.error("Unrecognized redshift evolution model. Choose 'PowerLaw', 'Madau' or 'None'.")
print(f'Assuming a {redshift_evolution} redshift evolution model')

min_pixels = int(opts.min_pixels)
sky_area = float(opts.sky_area)

# Build the multi-event likelihood object from everything prepared above.
likelihood_class = gwcosmo.likelihood.dark_siren_likelihood.PixelatedGalaxyCatalogMultipleEventLikelihood
me = likelihood_class(
    posterior_samples_dictionary,
    skymaps_dictionary,
    injections,
    LOS_catalog_path,
    ps_z,
    cosmo,
    mass_priors,
    min_pixels=min_pixels,
    sky_area=sky_area,
    posterior_samples_field=posterior_samples_field_dictionary,
    network_snr_threshold=opts.snr,
)

# Warn about entries of the parameter file that the likelihood does not list
# among its parameters.
unrecognised = [key for key in parameter_dict if key not in me.parameters.keys()]
for key in unrecognised:
    print(f'WARNING: The parameter {key} from your parameter dictionary is not recognised by the likelihood module.')


# Run the analysis with the requested method.
#   'sampling': hand the likelihood and priors to bilby.run_sampler.
#   'gridded' : evaluate the log-likelihood on a regular grid, save the result
#               to '<outputfile>.npz' and optionally plot it to '<outputfile>.png'.
if method == 'sampling':
    # Sampler settings are only read when sampling is requested.
    sampler = str(opts.sampler)
    nwalkers = int(opts.nwalkers)
    walks = int(opts.walks)
    nsteps = int(opts.nsteps)
    nlive = int(opts.nlive)
    dlogz = float(opts.dlogz)
    npool = int(opts.npool)

    # List-valued parameters become bilby priors built from their first two
    # entries; scalar-valued parameters are passed as fixed floats.
    priors = {}
    for key, settings in parameter_dict.items():
        value = settings['value']
        if isinstance(value, list):
            label = settings.get('label', key)
            prior_type = settings['prior']
            if prior_type in ('Uniform', 'uniform'):
                priors[key] = bilby.core.prior.Uniform(value[0], value[1], name=key, latex_label=label)
            elif prior_type == 'Gaussian':
                priors[key] = bilby.core.prior.Gaussian(value[0], value[1], name=key, latex_label=label)
            else:
                raise ValueError(f"Unrecognised prior settings for parameter {key} (specify 'Uniform' or 'Gaussian').")
        else:
            priors[key] = float(value)

    sampling_method = "acceptance-walk" # can be "rwalk", "act-walk" (the default). "acceptance-walk" is faster (as of 20230717), see https://lscsoft.docs.ligo.org/bilby/dynesty-guide.html

    bilbyresult = bilby.run_sampler(
        likelihood=me,
        priors=priors,
        outdir=f"{opts.outputfile}",     # the directory to output the results to
        label=f"{opts.outputfile}",      # a label to apply to the output file
        plot=plot,         # by default this is True and a corner plot will be produced
        sampler=sampler,  # set the name of the sampler
        nlive=nlive,         # add in any keyword arguments used by that sampler
        dlogz=dlogz,
        nwalkers=nwalkers,         # add in any keyword arguments used by that sampler
        walks=walks,
        nsteps=nsteps,
        npool=npool,
        sample=sampling_method
        #injection_parameters={"m": m, "c": c},  # (optional) true values for adding to plots
    )


elif method == 'gridded':

    parameter_grid = {}
    fixed_params = {}

    # List-valued parameters are gridded: [min, max] uses 10 points and
    # [min, max, n] uses n points. Scalar-valued parameters are held fixed.
    for key, settings in parameter_dict.items():
        value = settings['value']
        if isinstance(value, list):
            if settings['prior'] in ('Uniform', 'uniform'):
                print(f"Setting a uniform prior on {key} in the range [{value[0]}, {value[1]}]")
            else:
                raise ValueError(f"Unrecognised prior settings for parameter {key} (specify 'Uniform').")
            if len(value) == 3:
                parameter_grid[key] = np.linspace(value[0], value[1], value[2])
            elif len(value) == 2:
                parameter_grid[key] = np.linspace(value[0], value[1], 10)
            else:
                raise ValueError(f"Incorrect formatting for parameter {key}.")
        else:
            fixed_params[key] = float(value)

    for key, fixed_value in fixed_params.items():
        me.parameters[key] = fixed_value
        print(f'Setting parameter {key} to {fixed_value}')

    # One array axis per gridded parameter, in the order of parameter_grid.
    names = list(parameter_grid.keys())
    shape = [len(x) for x in parameter_grid.values()]
    likelihood = np.zeros(shape)

    # Rough estimate assuming 0.25 s per grid point per event.
    estimated_seconds = 0.25*likelihood.size*len(posterior_samples_dictionary)
    print(f'Estimated runtime for this set of parameters and events is {round(estimated_seconds,1)} seconds, which is {round(estimated_seconds/60,1)} minutes or {round(estimated_seconds/(60*60),1)} hours.')

    # Both products iterate in the same order, so each index tuple addresses
    # the array element belonging to the matching tuple of parameter values.
    grid_indices = product(*(range(n) for n in shape))
    grid_points = product(*parameter_grid.values())

    print(f'Computing the likelihood on a grid over {names}')
    for index, point in tqdm(zip(grid_indices, grid_points), total=likelihood.size):
        for name, point_value in zip(names, point):
            me.parameters[name] = point_value
        likelihood[index] = me.log_likelihood()

    # Rescale the log-likelihood by its largest finite value before
    # exponentiating. If no grid point is finite the shift is skipped, since
    # subtracting an infinite maximum would turn the whole grid into NaN.
    finite_log_likelihood = likelihood[np.isfinite(likelihood)]
    if finite_log_likelihood.size > 0:
        likelihood -= np.max(finite_log_likelihood)
    else:
        print('WARNING: The log-likelihood is not finite anywhere on the grid.')
    likelihood = np.exp(likelihood)

    param_values = [parameter_grid[k] for k in names]

    # Everything is stored as one object array inside the .npz file.
    mylist = np.array([names,param_values,likelihood,opts,parameter_dict],dtype=object)
    np.savez(outputfile+'.npz',mylist)

    if plot:
        no_params = len(names)
        labels = [parameter_dict[name].get("label", name) for name in names]

        if no_params == 1:
            # Normalise to unit area using the (uniform) grid spacing.
            grid = param_values[0]
            ind_lik_norm = likelihood/np.sum(likelihood)/(grid[1]-grid[0])
            plt.figure(figsize=[4.2,4.2])
            plt.plot(grid,ind_lik_norm)
            plt.xlabel(labels[0],fontsize=16)
            plt.xlim(grid[0],grid[-1])
            plt.ylim(0,1.1*np.max(ind_lik_norm))
            plt.ylabel(f'p({labels[0]})',fontsize=16)

        else:
            # Corner-style layout: 1D marginals on the diagonal, 2D marginals
            # below it, and the panels above the diagonal removed.
            fig, ax = plt.subplots(no_params, no_params, figsize=[4*no_params,4*no_params],constrained_layout=True)
            for column in range(no_params):
                for row in range(no_params):
                    # Axes of the likelihood array to sum over for this panel.
                    marginal_axes = list(range(no_params))
                    if column > row:
                        fig.delaxes(ax[row][column])

                    elif row == column:
                        marginal_axes.remove(row)
                        ind_lik = np.sum(likelihood,axis=tuple(marginal_axes))
                        ind_lik_norm = ind_lik/np.sum(ind_lik)/(param_values[row][1]-param_values[row][0])

                        ax[row,column].plot(param_values[row],ind_lik_norm)
                        ax[row,column].set_xlim(param_values[row][0],param_values[row][-1])
                        ax[row,column].set_ylim(0,1.1*np.max(ind_lik_norm))

                    else:
                        marginal_axes.remove(row)
                        marginal_axes.remove(column)
                        # column < row here, so the summed array is indexed
                        # (column, row); transpose so its rows follow the y values.
                        ax[row,column].contourf(param_values[column],param_values[row],np.sum(likelihood,axis=tuple(marginal_axes)).T,20)

                    if column == 0:
                        ax[row,column].set_ylabel(labels[row], fontsize=16)
                    if row == no_params-1:
                        ax[row,column].set_xlabel(labels[column],fontsize=16)

        plt.savefig(f'{outputfile}.png',dpi=100,bbox_inches='tight')

else:
    # Without this branch an unsupported method would end the script silently
    # without computing anything.
    parser.error(f"Unrecognized method '{method}'. Choose 'sampling' or 'gridded'.")



