#! /usr/bin/env python

import argparse
import json
import logging
import os
from collections import defaultdict
from functools import lru_cache
from itertools import combinations
from pathlib import Path
from time import perf_counter as timer

from obspy import read, read_inventory
from obspy.taup import TauPyModel
from obspy.taup.taup_create import build_taup_model
from tqdm import tqdm

import pyres as pr

# Command-line interface of the script.
# Positional arguments: CATALOGUE INVENTORY PARAMETERS [OUTPUT_DIR].
parser = argparse.ArgumentParser(prog=__file__)
parser.add_argument('catalogue', type=Path)
parser.add_argument('inventory', type=Path)
parser.add_argument('parameters', type=Path)
# `nargs='?'` is what makes the default effective: argparse ignores `default` on a
# positional that is still required. On older Python releases an optional positional
# is bound as soon as the preceding positionals are parsed, so when OUTPUT_DIR is
# given it should directly follow PARAMETERS (before any flags).
parser.add_argument('output_dir', type=Path, nargs='?', default=Path(),
                    help="Directory where the results are written (default: current directory)")
parser.add_argument('--phase_file', type=Path, default=None,
                    help="Catalogue containing the phase picking and other information")
parser.add_argument('--phase_type', choices=['hypoinv', 'nll', 'quakeml', 'hypoel'], default=None,
                    help="Type of PHASE_FILE")
parser.add_argument('--taup_model', type=Path, default=None)
parser.add_argument('--rebuild_model', default=False, action='store_true')
parser.add_argument('--graphics_dir', default=None, type=Path)
parser.add_argument('--graphics_format', default='pdf')
parser.add_argument('--stop', default=False, action='store_true')
parser.add_argument('--log', help="Log level", default='info')
parser.add_argument('--progress', help="Show progress bar", default=False, action='store_true')


# Arguments are parsed at module level; the main section below reads `cli_args`.
cli_args = parser.parse_args()

# Picks need a source: either a phase file or a TauP model.
if cli_args.phase_file is None and cli_args.taup_model is None:
    parser.error("Either a phase file or a model must be provided")

if cli_args.phase_file is not None and cli_args.phase_type is None:
    parser.error("You must provide a type for the given phase file")

# Resolve the textual level (e.g. "info") to the numeric constant exposed by `logging`.
# An unknown name, or a `logging` attribute that is not a level, is reported as a usage
# error instead of surfacing as an AttributeError/ValueError traceback.
log_level = getattr(logging, cli_args.log.upper(), None)
if not isinstance(log_level, int):
    parser.error(f"Invalid log level: {cli_args.log!r}")

logging.basicConfig(format='%(levelname)s-%(asctime)s: %(message)s',
                    level=log_level)

if __name__ == '__main__':
    # Script body:
    #   1. load the inventory, the JSON parameters, the catalogue and the location errors;
    #   2. analyse every pair of catalogue events station by station and collect, per pair,
    #      the stations passing the cross-correlation and S-P delay criteria in `REs`;
    #   3. group the retained pairs into families and write one `<i>_event.sel` file per
    #      family, plus an `<i>_dt.cc` file for families with more than two events.
    tic = timer()

    with cli_args.inventory.open('r') as file:
        inventory = read_inventory(file)
    with cli_args.parameters.open('r') as file:
        parameters = json.load(file)
    catalogue = pr.read_zmap(cli_args.catalogue, extensions=['name', 'path'])
    errors = pr.read_errors(catalogue, cli_args.phase_file, cli_args.phase_type)

    if cli_args.taup_model:
        model_name = str(cli_args.taup_model)
        tvel_file = str(cli_args.taup_model.with_suffix('.tvel'))
        if cli_args.rebuild_model:
            build_taup_model(tvel_file)
        try:
            taup_model = TauPyModel(model=model_name)
        except FileNotFoundError:
            # The model has not been built yet: build it from the .tvel file and retry once.
            build_taup_model(tvel_file)
            taup_model = TauPyModel(model=model_name)
        # Memoise travel-time queries (arguments must therefore be hashable).
        get_travel_times = lru_cache()(taup_model.get_travel_times)
    else:
        get_travel_times = None

    # Maps (index of event 1, index of event 2) -> {station: (cc_value, delta_sp)}.
    REs = defaultdict(dict)
    n_events = len(catalogue)
    for event_1, event_2 in tqdm(combinations(catalogue.itertuples(), 2),
                                 total=n_events * (n_events - 1) // 2,
                                 disable=not cli_args.progress):
        pair_key = (event_1.Index, event_2.Index)
        # Stations are gathered locally and committed to `REs` only once the whole pair has
        # been processed and the station-count criterion is met, so that a pair aborted by
        # an error never leaves a partial, unchecked entry behind.
        pair_stations = {}
        try:
            logging.debug(f"Analyzing pair {event_1.name}, {event_2.name}")
            if cli_args.graphics_dir:
                graphics_dir = cli_args.output_dir / f"{cli_args.graphics_dir}/{event_1.name}_{event_2.name}/"
            else:
                graphics_dir = None
            # Thresholds are selected from the smaller of the two magnitudes.
            freq_range, delta_sp_threshold, min_stations = pr.piecewise_from_thresholds(
                min(event_1.magnitude, event_2.magnitude), parameters['thresholds'])
            # Midpoint of the two epicentres; it is the same for every station of the pair.
            event_coordinates = ((event_1.latitude + event_2.latitude) / 2,
                                 (event_1.longitude + event_2.longitude) / 2)

            for trace_1, trace_2 in pr.zip_streams(read(event_1.path), read(event_2.path)):
                station = trace_1.stats.station
                station_coordinates = pr.get_coordinates(inventory, station)
                pick_p_1, pick_s_1 = pr.get_picks(event_1, event_coordinates, station_coordinates, trace_1, parameters,
                                                  cli_args.phase_file, cli_args.phase_type, get_travel_times)
                pick_p_2, pick_s_2 = pr.get_picks(event_2, event_coordinates, station_coordinates, trace_2, parameters,
                                                  cli_args.phase_file, cli_args.phase_type, get_travel_times)

                # First correlation on the raw data: only the shift is used, to align the traces.
                cc_shift_raw, _ = pr.correlate_waves(trace_1.data, trace_2.data, parameters['max_shift'],
                                                     normalize=None)

                pick_p_mean_delay = pr.relative_pick_time(trace_1.stats, trace_2.stats, pick_p_1, pick_p_2,
                                                          cc_shift_raw)
                pick_s_mean_delay = pr.relative_pick_time(trace_1.stats, trace_2.stats, pick_s_1, pick_s_2,
                                                          cc_shift_raw)

                pr.sync_traces(trace_1, trace_2, cc_shift_raw)

                trace_1_filtered = pr.cc_preprocess(trace_1, pick_p_mean_delay, pick_s_mean_delay,
                                                    freq_range, parameters['full_waveform_window'])
                trace_2_filtered = pr.cc_preprocess(trace_2, pick_p_mean_delay, pick_s_mean_delay,
                                                    freq_range, parameters['full_waveform_window'])

                # Second correlation on the preprocessed traces, normalised afterwards.
                cc_shift_filtered, cc_value_unnormalized = pr.correlate_waves(trace_1_filtered.data,
                                                                              trace_2_filtered.data,
                                                                              parameters['max_shift'],
                                                                              normalize=None)
                cc_value = pr.normalize_cc(trace_1_filtered.data, trace_2_filtered.data, cc_value_unnormalized,
                                           cc_shift_filtered)

                if cc_value < parameters['cross-correlation_threshold']:
                    logging.debug(f"Cross-correlation {cc_value} lower than threshold for {station}")
                    continue

                if cli_args.graphics_dir:
                    # Created lazily, so pairs without any accepted station leave no empty directory.
                    graphics_dir.mkdir(parents=True, exist_ok=True)
                    figure_stem = f"{event_1.name}_{event_2.name}_{station}"

                    starttime = parameters['full_waveform_window'][0] - pick_p_mean_delay
                    pr.sync_traces(trace_1_filtered, trace_2_filtered, cc_shift_filtered)
                    pr.plot_signals(trace_1_filtered, trace_2_filtered,
                                    event_1.magnitude, event_2.magnitude,
                                    starttime + pick_p_mean_delay - parameters['p_waveform_window'][0],
                                    starttime + pick_p_mean_delay + parameters['p_waveform_window'][1],
                                    starttime + pick_s_mean_delay - parameters['s_waveform_window'][0],
                                    starttime + pick_s_mean_delay + parameters['s_waveform_window'][1],
                                    cc_value,
                                    freq_range,
                                    graphics_dir / f"{figure_stem}_CCZ.{cli_args.graphics_format}")
                    p_graphics_path = graphics_dir / f"{figure_stem}_CSP.{cli_args.graphics_format}"
                    s_graphics_path = graphics_dir / f"{figure_stem}_CSS.{cli_args.graphics_format}"
                else:
                    p_graphics_path = s_graphics_path = None

                # A failure of the cross-spectrum analysis only discards the current station.
                # `Exception` (not `BaseException`) so Ctrl-C and SystemExit still stop the run.
                try:
                    time_delay_p = pr.cross_spectrum_analysis(trace_1, trace_2, pick_p_mean_delay,
                                                              parameters['p_waveform_window'],
                                                              freq_range, parameters['max_shift'], parameters,
                                                              p_graphics_path)
                except Exception as exception:
                    logging.debug(f"An error occurred during processing of P-wave for {station}", exc_info=exception)
                    continue

                try:
                    time_delay_s = pr.cross_spectrum_analysis(trace_1, trace_2, pick_s_mean_delay,
                                                              parameters['s_waveform_window'],
                                                              freq_range, parameters['max_shift'], parameters,
                                                              s_graphics_path)
                except Exception as exception:
                    logging.debug(f"An error occurred during processing of S-wave for {station}", exc_info=exception)
                    continue

                delta_sp = time_delay_s - time_delay_p
                if abs(delta_sp) < delta_sp_threshold:
                    pair_stations[station] = (cc_value, delta_sp)

            if len(pair_stations) >= min_stations:
                REs[pair_key] = pair_stations
                if cli_args.graphics_dir:
                    pr.plot_similarity(event_1, event_2, REs, inventory, parameters['cross-correlation_threshold'],
                                       delta_sp_threshold,
                                       graphics_dir / f"{event_1.name}-{event_2.name}.{cli_args.graphics_format}")

        except Exception as exception:
            # Without --stop a failing pair is reported and skipped; interrupts are not
            # caught here and therefore always terminate the run.
            if cli_args.stop:
                raise
            logging.warning(f"An error occurred while processing pair ({event_1.path}, {event_2.path})",
                            exc_info=exception)

    families = pr.connected_components(REs.keys())
    logging.info(f"RE families: {', '.join(str([catalogue.loc[n, 'name'] for n in family]) for family in families)}")
    cli_args.output_dir.mkdir(parents=True, exist_ok=True)
    for i, family in enumerate(families):
        with open(cli_args.output_dir / f"{i}_event.sel", "w") as file:
            for n in family:
                # A missing (or falsy) uncertainty falls back to the `default_*` parameter.
                t_err, h_err, v_err = (
                    errors[n].get(key) or parameters['default_' + key]
                    for key in ('time_uncertainty', 'horizontal_uncertainty', 'vertical_uncertainty'))
                file.write(f"{catalogue.loc[n, 'date'].strftime('%Y%m%d  %H%M%S%f')}   "
                           f"{catalogue.loc[n, 'latitude']:.4f}     {catalogue.loc[n, 'longitude']:.4f}    "
                           f"{catalogue.loc[n, 'depth']:.3f}   {catalogue.loc[n, 'magnitude']:.1f}    "
                           f"{t_err:.2f}    {h_err:.2f}   {v_err:.2f}        "
                           f"{catalogue.loc[n, 'name']}\n")
        # NOTE(review): families of exactly two events get no dt.cc file - confirm this is intended.
        if len(family) > 2:
            with open(cli_args.output_dir / f"{i}_dt.cc", "w") as file:
                # NOTE(review): pairs are looked up in the order given by `family`; this assumes
                # it matches the catalogue order used as key in `REs` - verify against
                # pr.connected_components.
                for (t1, t2) in combinations(family, 2):
                    # `.get` avoids inserting empty entries into the defaultdict for unmatched pairs.
                    pair_results = REs.get((t1, t2))
                    if pair_results:
                        file.write(f"#    {catalogue.loc[t1, 'name']}    {catalogue.loc[t2, 'name']}     0.0\n")
                        delta_v = parameters['p_wave_speed'] - parameters['s_wave_speed']
                        for station in sorted(pair_results):
                            cc, delta_sp = pair_results[station]
                            # One P and one S line per station, both derived from the S-P delay
                            # difference and the configured wave speeds.
                            ttp = parameters['s_wave_speed'] * delta_sp / delta_v
                            file.write(f"{station}     {ttp: 10.9f}    {cc:.2f}    P\n")
                            tts = -parameters['p_wave_speed'] * delta_sp / delta_v
                            file.write(f"{station}     {tts: 10.9f}    {cc:.2f}    S\n")

    toc = timer()
    logging.debug(f"Elapsed time: {toc - tic:.2f} seconds")
