#!/bin/env python

import argparse
import distutils
import json
import os
import sys
import warnings
from pathlib import Path
from typing import List

import bilby
import matplotlib.pyplot as plt
import numpy as np
from loguru import logger

from pygwb.baseline import Baseline
from pygwb.detector import Interferometer
from pygwb.parameters import Parameters, ParametersHelp


def help_arguments(parent):
    """Build a parser exposing every ``Parameters`` field as an option.

    The returned parser inherits all arguments of ``parent`` and adds one
    optional ``--<name>`` string argument per entry found in
    ``Parameters.__annotations__``, using ``ParametersHelp[<name>].help`` as
    the help text.

    Parameters
    ----------
    parent : argparse.ArgumentParser
        Parser whose arguments are inherited by the returned parser.

    Returns
    -------
    argparse.ArgumentParser
        Parser holding the parent arguments plus one option per annotated
        ``Parameters`` field.
    """
    annotations = getattr(Parameters, "__annotations__", {})
    full_parser = argparse.ArgumentParser(parents=[parent])
    for field_name in annotations:
        options = {
            "help": ParametersHelp[field_name].help,
            "type": str,
            "required": False,
        }
        # Fields annotated as a bare ``List`` take one or more values.
        if annotations[field_name] == List:
            options["nargs"] = "+"
        full_parser.add_argument(f"--{field_name}", **options)
    return full_parser

def _str_to_bool(value):
    """Convert a command-line truth string into a ``bool``.

    Stand-in for ``distutils.util.strtobool``: ``distutils`` is deprecated
    (and no longer part of the standard library from Python 3.12), and a bare
    ``import distutils`` does not guarantee that the ``distutils.util``
    submodule has been loaded.

    Parameters
    ----------
    value : str
        One of ``y, yes, t, true, on, 1`` (true) or ``n, no, f, false, off,
        0`` (false); case and surrounding whitespace are ignored.

    Returns
    -------
    bool
        The truth value spelled by ``value``.

    Raises
    ------
    ValueError
        If ``value`` is not one of the accepted spellings.
    """
    normalized = value.strip().lower()
    if normalized in ("y", "yes", "t", "true", "on", "1"):
        return True
    if normalized in ("n", "no", "f", "false", "off", "0"):
        return False
    raise ValueError(f"invalid truth value {value!r}")


def _flag_or_default(raw_value, default):
    """Return the boolean spelled by ``raw_value``, or ``default`` if unset.

    ``raw_value`` counts as unset when it is ``None`` or an empty string.
    """
    if not raw_value:
        return default
    return _str_to_bool(raw_value)


def _load_gate_file(gate_path, t0, tf):
    """Load the ``.npz`` archive the gates are read from.

    Parameters
    ----------
    gate_path : pathlib.Path
        Either the archive itself or a location that is globbed for
        ``point_estimate*.npz`` files.
    t0, tf
        Values whose integer parts form the ``<t0>-<tf>`` tag a globbed file
        name must contain.

    Returns
    -------
    The object returned by ``numpy.load`` for the selected file.

    Raises
    ------
    FileNotFoundError
        If ``gate_path`` is not a file and no globbed file carries the tag.
    """
    if gate_path.is_file():
        return np.load(gate_path)

    time_tag = f"{int(t0)}-{int(tf)}"
    candidates = [
        candidate
        for candidate in sorted(gate_path.glob("point_estimate*.npz"))
        if candidate.match(f"*{time_tag}*")
    ]
    if not candidates:
        raise FileNotFoundError(
            f"No 'point_estimate*.npz' file matching '{time_tag}' found in {gate_path}."
        )
    # Paths produced by glob() already include ``gate_path``; joining the
    # directory onto them again would duplicate it for relative paths.
    return np.load(candidates[0])


def main():
    """Run the analysis for one pair of interferometers from the command line.

    Steps, in order:

    1. Parse the script options and any ``Parameters`` options, optionally
       starting from ``--param_file``.
    2. Save the final parameter set to a file in the output location.
    3. Build the two interferometers and, if ``gate_data`` is set, apply gates
       read from ``path_gate_data`` or computed by the autogating procedure.
    4. Build the baseline, set its PSDs/CSDs and run the delta sigma cut.
    5. Optionally compute the coherence spectrum (``--calc_coh``, default
       False) and the point estimate and sigma (``--calc_pt_est``, default
       True), then save the results. The baseline is pickled only when
       ``--pickle_out`` is true and the point estimate was computed.

    Boolean script options are passed as strings (``true``/``false``,
    ``yes``/``no``, ``1``/``0``, ...). ``--apply_dsc`` and ``--wipe_ifo``
    default to True, ``--pickle_out`` to False.

    Raises
    ------
    ValueError
        If the notch list path does not exist, if fewer than two
        interferometers are configured, if a boolean option has an invalid
        spelling, or if the frequency spacing of the computed spectra differs
        from the requested resolution by more than 1e-6.
    FileNotFoundError
        If gates are to be read from a location holding no matching file.
    """
    params = Parameters()

    pipe_parser = argparse.ArgumentParser(add_help=False)
    script_options = (
        ("--param_file", "Parameter file to use for analysis."),
        ("--output_path", "Location to save output to."),
        ("--calc_coh", "Calculate coherence spectrum from data."),
        ("--calc_pt_est", "Calculate omega point estimate and sigma from data."),
        ("--apply_dsc", "Apply delta sigma cut when calculating final output."),
        ("--pickle_out", "Pickle output Baseline of the analysis. Default is False."),
        ("--wipe_ifo", "Wipe interferometer data to reduce size of pickled Baseline."),
        ("--file_tag", "File naming tag. By default, reads in first and last time in dataset."),
    )
    for flag, help_text in script_options:
        pipe_parser.add_argument(
            flag, help=help_text, action="store", type=str, required=False
        )

    # This parse is only there for the help output; its result is discarded.
    help_arguments(pipe_parser).parse_known_args()

    script_args, parameters_args = pipe_parser.parse_known_args()

    if script_args.param_file:
        params.update_from_file(script_args.param_file)
    else:
        warnings.warn(
            "No parameter file was passed - script runinng on script arguments and defaults. "
            'Friendly reminder: the parameter file argument is "param_file".'
        )

    params.update_from_arguments(parameters_args)

    if params.notch_list_path:
        notch_list_path = Path(params.notch_list_path)
        if not notch_list_path.exists():
            raise ValueError(
                f"Your path to the notch list {params.notch_list_path} does not exist or cannot be reached."
                " Please enter a valid path instead or leave it blank to use no notches at all."
            )

    logger.info("Successfully imported parameters from paramfile and input.")

    if script_args.output_path:
        output_path = Path(script_args.output_path)
        output_path.mkdir(parents=True, exist_ok=True)
    else:
        output_path = Path('')

    pipe_calculate_coherence = _flag_or_default(script_args.calc_coh, False)
    pipe_calculate_point_estimate = _flag_or_default(script_args.calc_pt_est, True)
    pipe_apply_dsc = _flag_or_default(script_args.apply_dsc, True)
    pipe_pickle_baseline = _flag_or_default(script_args.pickle_out, False)
    wipe_ifo = _flag_or_default(script_args.wipe_ifo, True)

    # Without an explicit tag, output files are named after the analysed times.
    file_tag = script_args.file_tag or f"{int(params.t0)}-{int(params.tf)}"

    if script_args.param_file:
        param_file = Path(script_args.param_file)
        outfile_path = f"{output_path}/{param_file.stem}_{file_tag}_final{param_file.suffix}"
    else:
        outfile_path = Path(f"{output_path}/parameters_{file_tag}_final.ini")
    params.save_paramfile(str(outfile_path))
    logger.info(f"Saved final param file at {outfile_path}.")

    param_dict = params.parse_ifo_parameters()
    ifo_list = params.interferometer_list
    if len(ifo_list) < 2:
        raise ValueError(
            f"Two interferometers are needed to form a baseline, got {list(ifo_list)}."
        )
    ifo_1 = Interferometer.from_parameters(ifo_list[0], param_dict[ifo_list[0]])
    ifo_2 = Interferometer.from_parameters(ifo_list[1], param_dict[ifo_list[1]])
    logger.info("Loaded up interferometers with selected parameters.")

    if params.gate_data:
        if params.path_gate_data:
            logger.info("Loading gates from file.")
            params.path_gate_data = Path(params.path_gate_data)
            npzobject = _load_gate_file(params.path_gate_data, params.t0, params.tf)
            # The second argument is the 1-based position of the
            # interferometer within the baseline.
            for position, (ifo, ifo_obj) in enumerate(
                zip(ifo_list, (ifo_1, ifo_2)), start=1
            ):
                ifo_obj.apply_gates_from_file(
                    npzobject,
                    position,
                    gate_tpad=param_dict[ifo].gate_tpad,
                )
                logger.info(f"Gates loaded and applied for IFO {ifo}:{ifo_obj.gates}")
        else:
            logger.info("Applying autogating procedure.")
            for ifo, ifo_obj in zip(ifo_list, (ifo_1, ifo_2)):
                ifo_params = param_dict[ifo]
                ifo_obj.gate_data_apply(
                    gate_tzero=ifo_params.gate_tzero,
                    gate_tpad=ifo_params.gate_tpad,
                    gate_threshold=ifo_params.gate_threshold,
                    cluster_window=ifo_params.cluster_window,
                    gate_whiten=ifo_params.gate_whiten,
                )
                logger.info(f"Gates for IFO {ifo}:{ifo_obj.gates}")

    base_12 = Baseline.from_parameters(ifo_1, ifo_2, params)
    logger.info(
        f"Baseline with interferometers {ifo_1.name}, {ifo_2.name} initialised."
    )

    logger.info("Setting PSDs and CSDs of the baseline...")
    base_12.set_cross_and_power_spectral_density(params.frequency_resolution)
    base_12.set_average_power_spectral_densities()
    base_12.set_average_cross_spectral_density()
    logger.info("... done.")

    # Check nothing has gone wrong in the frequency handling: the spacing of
    # the computed spectra must match the requested resolution.
    delta_f = base_12.frequencies[1] - base_12.frequencies[0]
    if abs(delta_f - params.frequency_resolution) > 1e-6:
        raise ValueError("Frequency resolution in PSD/CSD is different than requested.")

    base_12.calculate_delta_sigma_cut(
        delta_sigma_cut=params.delta_sigma_cut,
        alphas=params.alphas_delta_sigma_cut,
        fref=params.fref,
        flow=params.flow,
        fhigh=params.fhigh,
        return_naive_and_averaged_sigmas=True,
    )

    logger.info(
        f"times flagged by the delta sigma cut as badGPStimes:{base_12.badGPStimes}"
    )

    if pipe_calculate_coherence:
        logger.info("calculating the coherence...")
        base_12.set_coherence_spectrum(
            apply_dsc=pipe_apply_dsc, flow=params.flow, fhigh=params.fhigh
        )

    if pipe_calculate_point_estimate:
        logger.info("calculating the point estimate and sigma...")

        base_12.set_point_estimate_sigma(
            alpha=params.alpha,
            fref=params.fref,
            flow=params.flow,
            fhigh=params.fhigh,
            apply_dsc=pipe_apply_dsc,
        )

        logger.success(f"Ran stochastic search over times {file_tag}")
        logger.success(f"\tPOINT ESTIMATE: {base_12.point_estimate:e}")
        logger.success(f"\tSIGMA: {base_12.sigma:e}")

        # File names are kept as plain strings, as in the calls below.
        data_file_name = f"{output_path}/point_estimate_sigma_{file_tag}"
        logger.info(
            "Saving point_estimate and sigma spectrograms, spectra, and final values to file."
        )
        base_12.save_point_estimate_spectra(
            params.save_data_type,
            data_file_name,
        )

        logger.info("Saving average psds and csd to file.")
        csd_file_name = f"{output_path}/psds_csds_{file_tag}"
        base_12.save_psds_csds(
            params.save_data_type,
            csd_file_name,
        )

        if pipe_pickle_baseline:
            logger.info("Pickling the baseline.")
            pickle_name = f"{output_path}/{base_12.name}_{file_tag}.pickle"
            base_12.save_to_pickle(pickle_name, wipe=wipe_ifo)

    else:
        logger.info("Saving average psds and csd to file.")

        data_file_name = f"{output_path}/psds_csds_{file_tag}"

        base_12.npz_save_psds_csds(
            data_file_name,
            base_12.frequencies,
            base_12.csd,
            base_12.interferometer_1.average_psd,
            base_12.interferometer_2.average_psd,
        )

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
