#!/bin/env python

import argparse
import json
from pathlib import Path

import bilby
import numpy as np
import pycbc.psd
from pycbc.psd.analytical import (
    aLIGOAPlusDesignSensitivityT1800042,
    aLIGODesignSensitivityP1200087,
)

from pygwb.detector import Interferometer
from pygwb.network import Network
from pygwb.simulator import default_waveform_generator_arguments


def _build_simulate_parser():
    """Build the command-line parser for the data-simulation script.

    Returns
    -------
    argparse.ArgumentParser
        Parser exposing the segment/observation timing options, the
        injection file, the detector and sensitivity lists, the output
        location and the optional waveform-generator overrides.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--duration", "-d",
        help="Duration of each data segment simulated in seconds.",
        action="store", type=int, required=False, default=64,
    )
    parser.add_argument(
        "--start_time", "-ts",
        help="Start time of the observation in seconds.",
        action="store", type=float, default=0.0,
    )
    parser.add_argument(
        "--observing_time", "-Tobs",
        help="Duration of the observation in seconds.",
        action="store", type=int, required=True,
    )
    parser.add_argument(
        "--sampling_frequency", "-fs",
        help="Sampling frequency of the data in Hz.",
        action="store", type=int, default=2048,
    )
    parser.add_argument(
        "--injection_file", "-if",
        help="Bilby injection json dictionary.",
        action="store", type=Path, required=True,
    )
    parser.add_argument(
        "--detectors", "-det",
        help="Detectors to simulate data for.",
        action="store", type=str, required=False, nargs='+', default=['H1', 'L1'],
    )
    parser.add_argument(
        "--sensitivity", "-sn",
        help="Sensitivity of the detectors. You can find all possible sensitivities at https://pycbc.org/pycbc/latest/html/pycbc.psd.html .",
        action="store", type=str, required=False, nargs='+', default=['None'],
    )
    parser.add_argument(
        "--outdir", "-od",
        help="Output path.",
        action="store", type=Path, required=False,
    )
    parser.add_argument(
        "--waveform_duration", "-wd",
        help="Duration to use for waveform generation.",
        action="store", type=int, required=False, default=None,
    )
    parser.add_argument(
        "--waveform_approximant", "-wa",
        help="Waveform approximation to use for waveform generation.",
        action="store", type=str, required=False, default=None,
    )
    parser.add_argument(
        "--waveform_reference_frequency", "-wrf",
        help="Waveform reference_frequency to use for waveform generation.",
        action="store", type=float, required=False, default=None,
    )
    parser.add_argument(
        "--channel_name", "-cn",
        help="Channel name with which data are saved.",
        action="store", type=str, required=False, default=None,
    )
    parser.add_argument(
        "--save_file_format", "-sff",
        help="File format in which to save the data.",
        action="store", type=str, required=False, default="gwf",
    )
    return parser


def _expand_sensitivities(sensitivities, n_detectors):
    """Return one sensitivity name per detector.

    Parameters
    ----------
    sensitivities : list of str
        Names given on the command line. A single name is applied to
        every detector.
    n_detectors : int
        Number of detectors being simulated.

    Returns
    -------
    list of str
        A list of length ``n_detectors``.

    Raises
    ------
    ValueError
        If more than one name is given and the count does not match
        ``n_detectors``. Pairing the two lists with ``zip`` would otherwise
        silently drop the unmatched entries.
    """
    if len(sensitivities) == 1:
        return list(sensitivities) * n_detectors
    if len(sensitivities) != n_detectors:
        raise ValueError(
            f"Got {len(sensitivities)} sensitivities for {n_detectors} detectors; "
            "give either a single sensitivity or exactly one per detector."
        )
    return list(sensitivities)


def _noise_array_from_sensitivity(sens, frequencies):
    """Return the noise array for one sensitivity name on ``frequencies``.

    Parameters
    ----------
    sens : str
        ``'None'`` selects an all-zero array, ``'Aplus'`` is shorthand for
        ``'aLIGOAPlusDesignSensitivityT1800042'``; any other value is handed
        to ``pycbc.psd.analytical.from_string`` unchanged.
    frequencies : numpy.ndarray
        Frequency array of the detector. The spacing is taken from its
        first two entries, so it needs at least two samples.

    Returns
    -------
    numpy.ndarray
        Array with the same number of entries as ``frequencies``.

    Raises
    ------
    ValueError
        If pycbc rejects the name with a ``ValueError``.
    """
    if sens == 'None':
        return np.zeros_like(frequencies)
    if sens == 'Aplus':
        sens = 'aLIGOAPlusDesignSensitivityT1800042'
    df = frequencies[1] - frequencies[0]
    try:
        noise_array = np.array(
            pycbc.psd.analytical.from_string(sens, frequencies.size, df, frequencies[0])
        )
    except ValueError as err:
        raise ValueError(f"Sensitivity {sens} is not supported.") from err
    # Entries below 1e-100 are replaced by a small positive floor.
    # NOTE(review): presumably this avoids (near-)zero PSD bins - confirm.
    noise_array[noise_array < 1.e-100] = 1.e-41
    return noise_array


def main():
    """Simulate detector data with CBC injections and save it to file.

    Steps, all driven by the command-line arguments:

    1. Parse and validate the arguments.
    2. Copy the default waveform-generator arguments and apply the optional
       duration / approximant / reference-frequency overrides.
    3. Read the injections from ``['injections']['content']`` of the JSON
       injection file.
    4. Create one empty interferometer per requested detector and set its
       start time, segment duration, sampling frequency and power spectral
       density.
    5. Build a ``Network`` from the interferometers, let it simulate the
       data and save the data to ``--outdir``.

    Invalid argument combinations (non-positive ``--duration``, a number of
    sensitivities that is neither 1 nor the number of detectors) end the
    program through ``ArgumentParser.error``.

    Raises
    ------
    ValueError
        If a detector name or a sensitivity name is not recognised.
    """
    # Local import so this function does not depend on a file-level change.
    import copy

    parser = _build_simulate_parser()
    simulate_args = parser.parse_args()
    if simulate_args.outdir is None:
        simulate_args.outdir = Path("./")

    # Fail early, before any file or detector is touched.
    if simulate_args.duration <= 0:
        parser.error("--duration must be a positive number of seconds.")
    try:
        sens_list = _expand_sensitivities(
            simulate_args.sensitivity, len(simulate_args.detectors)
        )
    except ValueError as err:
        parser.error(str(err))

    # number of data segments to generate
    N_segs = simulate_args.observing_time // simulate_args.duration

    # Work on a copy: the overrides below must not leak into the defaults
    # shared by every other user of the imported dictionary.
    waveform_generator_arguments = copy.deepcopy(default_waveform_generator_arguments)
    if simulate_args.waveform_duration is not None:
        waveform_generator_arguments['duration'] = simulate_args.waveform_duration
    if simulate_args.waveform_approximant is not None:
        waveform_generator_arguments['waveform_arguments']['waveform_approximant'] = (
            simulate_args.waveform_approximant
        )
    if simulate_args.waveform_reference_frequency is not None:
        waveform_generator_arguments['waveform_arguments']['reference_frequency'] = (
            simulate_args.waveform_reference_frequency
        )

    with open(simulate_args.injection_file, "r", encoding="utf-8") as file:
        injections = json.load(file)['injections']['content']

    # load detectors
    ifo_list = []
    for ifo in simulate_args.detectors:
        try:
            ifo_list.append(Interferometer.get_empty_interferometer(ifo))
        except ValueError as err:
            raise ValueError(f"Requested detector {ifo} not supported by Bilby.") from err

    for ifo in ifo_list:
        ifo.start_time = simulate_args.start_time
        ifo.duration = simulate_args.duration
        ifo.sampling_frequency = simulate_args.sampling_frequency

    # load noise arrays; every detector received the same duration and
    # sampling frequency above, and the first one's frequency array is used
    # for all of them.
    frequencies = ifo_list[0].frequency_array
    for ifo, sens in zip(ifo_list, sens_list):
        noise = _noise_array_from_sensitivity(sens, frequencies)
        ifo.power_spectral_density = bilby.gw.detector.PowerSpectralDensity(frequencies, noise)

    # The network name is the concatenation of the detector names, e.g. "H1L1".
    network_name = ''.join(simulate_args.detectors)
    net_sim = Network(network_name, ifo_list)

    net_sim.set_interferometer_data_from_simulator(
        N_segs,
        CBC_dict=injections,
        sampling_frequency=simulate_args.sampling_frequency,
        start_time=simulate_args.start_time,
        waveform_generator_arguments=waveform_generator_arguments,
    )

    net_sim.save_interferometer_data_to_file(
        channel_name=simulate_args.channel_name,
        save_dir=simulate_args.outdir,
        file_format=simulate_args.save_file_format,
    )

# Run the simulation only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()
