#! /usr/bin/env python

import argparse
import logging
import os
from collections import defaultdict
from functools import lru_cache
from itertools import combinations
from pathlib import Path

import numpy as np
from multitaper import MTCross
from obspy import read, read_inventory
from obspy.taup import TauPyModel
from obspy.taup.taup_create import build_taup_model
from tqdm import tqdm
from yaml import load, dump

try:
    from yaml import CLoader as Loader, CDumper as Dumper
except ImportError:
    from yaml import Loader, Dumper

import findres as fres

# Command-line interface. Defaults are shown in --help through ArgumentDefaultsHelpFormatter.
parser = argparse.ArgumentParser(
    prog=__file__,
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)

# Positional arguments: input catalogue, station inventory, parameters and output file.
parser.add_argument(
    'catalogue',
    type=Path,
    help="Modified ZMAP catalogue containing repeater candidates, their MSEED location, and names",
)
parser.add_argument(
    'inventory',
    type=Path,
    help="Inventory of the stations data",
)
parser.add_argument(
    'parameters',
    type=Path,
    help="Numerical parameters file",
)
parser.add_argument(
    'output',
    type=Path,
    help="Output YAML file",
)

# Sources for the phase arrival times: a phase file, a picker and/or a velocity model.
parser.add_argument(
    '--phase_file',
    type=Path,
    default=None,
    help="Catalogue containing the phase picking and other information if available",
)
parser.add_argument(
    '--phase_type',
    choices=['hypoinv', 'nll', 'quakeml', 'hypoel', *fres.custom_formats.registry],
    default=None,
    help="Type of PHASE_FILE",
)
parser.add_argument(
    '--picker',
    action='store_true',
    default=False,
    help="Enable picker",
)
parser.add_argument(
    '--taup_model',
    type=Path,
    default=None,
    help="Velocity model file without extension (assumed to be .tvel) if available",
)
parser.add_argument(
    '--rebuild_model',
    action='store_true',
    default=False,
    help="Force the rebuild of the velocity model (regenerate ObsPy local cache)",
)

# Optional outputs.
parser.add_argument(
    '--graphics_dir',
    type=Path,
    default=None,
    help="Where to put the graphics relative to OUTPUT_DIR, if not specified the graphics is not generated",
)
parser.add_argument(
    '--graphics_format',
    default='pdf',
    help="Graphics format, must be one of the extensions recognized by matplotlib",
)
parser.add_argument(
    '--hypodd',
    action='store_true',
    default=False,
    help="Whether to output hypodd input files",
)

# Event selection and diagnostics.
parser.add_argument(
    '--include',
    default=None,
    help="Include only the listed events from catalogue",
)
parser.add_argument(
    '--log',
    default='warning',
    help="Log level",
)
parser.add_argument(
    '--progress',
    action='store_true',
    default=False,
    help="Show progress bar",
)

cli_args = parser.parse_args()

# Arrival times need at least one source: a phase file or a velocity model.
if cli_args.phase_file is None and cli_args.taup_model is None:
    parser.error("Either a phase file or a model must be provided")

# A phase file cannot be parsed without knowing its format.
if cli_args.phase_file is not None and cli_args.phase_type is None:
    parser.error("You must provide a type for the given phase file")

# Resolve the textual level (case-insensitive) to the numeric constant of the
# logging module. Unknown names, and names that resolve to something that is
# not a level, are reported as a usage error instead of an AttributeError.
log_level = getattr(logging, cli_args.log.upper(), None)
if not isinstance(log_level, int):
    parser.error(f"Unknown log level: {cli_args.log!r}")

logging.basicConfig(format='%(levelname)s-%(asctime)s: %(message)s',
                    level=log_level)

if __name__ == '__main__':
    # Load the station inventory and the numerical parameters.
    with cli_args.inventory.open('r') as file:
        inventory = read_inventory(file)
    with cli_args.parameters.open('r') as file:
        # NOTE(security): Loader/CLoader is PyYAML's full loader and can build
        # arbitrary Python objects from tagged nodes. This is fine for a
        # parameters file written by the user; if the file can come from an
        # untrusted source, a safe loader should be used instead.
        parameters = load(file, Loader=Loader)

    # Read the candidate catalogue, optionally restricted to the events
    # listed (comma separated) in --include.
    catalogue = fres.read.zmap(cli_args.catalogue, extensions=['name', 'path'])
    if cli_args.include:
        # Blank entries (e.g. from a trailing comma) are skipped rather than
        # crashing on int('').
        include_list = [int(x) for x in cli_args.include.split(',') if x.strip()]
        catalogue = catalogue.loc[catalogue['name'].isin(include_list)]

    if cli_args.taup_model:
        model_name = str(cli_args.taup_model)
        tvel_file = str(cli_args.taup_model.with_suffix('.tvel'))
        if cli_args.rebuild_model:
            build_taup_model(tvel_file)
        try:
            taup_model = TauPyModel(model=model_name)
        except FileNotFoundError:
            # The model has not been built yet: build it from the .tvel file and retry.
            build_taup_model(tvel_file)
            taup_model = TauPyModel(model=model_name)
        # Memoize the travel-time queries of the model.
        get_travel_times = lru_cache()(taup_model.get_travel_times)
    else:
        get_travel_times = None

    # Accepted measurements, keyed by the catalogue indices of the event pair:
    # REs[(index_1, index_2)][(network, station)] = (cross-correlation, delta_sp)
    REs = defaultdict(dict)
    for event_1, event_2 in tqdm(combinations(catalogue.itertuples(), 2),
                                 total=len(catalogue) * (len(catalogue) - 1) // 2,
                                 disable=not cli_args.progress):
        logging.debug(f"Analyzing pair {event_1.name}, {event_2.name}")
        pair_key = (event_1.Index, event_2.Index)
        # Per-pair graphics directory; None when graphics are not requested.
        if cli_args.graphics_dir:
            graphics_dir = cli_args.output.parent / f"{cli_args.graphics_dir}/{event_1.name}_{event_2.name}/"
        else:
            graphics_dir = None
        # Thresholds are selected from the larger of the two magnitudes.
        freq_range, delta_sp_threshold, min_stations = fres.utils.piecewise_from_thresholds(
            max(event_1.magnitude, event_2.magnitude), parameters['thresholds'])
        # Midpoint of the two epicentres; it does not depend on the station,
        # so it is computed once per pair.
        event_coordinates = ((event_1.latitude + event_2.latitude) / 2,
                             (event_1.longitude + event_2.longitude) / 2)
        for trace_1, trace_2 in tqdm(fres.utils.zip_streams(read(event_1.path), read(event_2.path)),
                                     desc=f"({event_1.name}, {event_2.name})",
                                     leave=False,
                                     disable=not cli_args.progress):
            station = trace_1.stats.station
            network = trace_1.stats.network

            # Station coordinates and P/S picks for both events; a station
            # missing any of them is skipped.
            try:
                station_coordinates_1 = fres.read.coordinates(inventory, network, station, event_1.date,
                                                              tol=parameters['stations_distance_tolerance'])
                pick_p_1, pick_s_1 = fres.read.picks(event_1, event_coordinates, station_coordinates_1, trace_1,
                                                     parameters,
                                                     cli_args.phase_file, cli_args.phase_type,
                                                     cli_args.picker, parameters.get('picker_arguments'),
                                                     get_travel_times)
                station_coordinates_2 = fres.read.coordinates(inventory, network, station, event_2.date,
                                                              tol=parameters['stations_distance_tolerance'])
                pick_p_2, pick_s_2 = fres.read.picks(event_2, event_coordinates, station_coordinates_2, trace_2,
                                                     parameters,
                                                     cli_args.phase_file, cli_args.phase_type,
                                                     cli_args.picker, parameters.get('picker_arguments'),
                                                     get_travel_times)
            except fres.read.InventoryLookupError as err:
                logging.warning(f"Inventory lookup error for {network + '.' + station} "
                                f"while analyzing ({event_1.name}, {event_2.name}). Skipping.")
                logging.debug("Exception info.", exc_info=err)
                continue
            except fres.read.MissingPhaseDataError as err:
                logging.warning(f"Phases lookup error for {network + '.' + station} "
                                f"while analyzing ({event_1.name}, {event_2.name}). Skipping.")
                logging.debug("Exception info.", exc_info=err)
                continue
            except fres.read.SPickEstimationError as err:
                logging.warning(f"S pick estimation error for {network + '.' + station} "
                                f"while analyzing ({event_1.name}, {event_2.name}). Skipping.")
                logging.debug("Exception info.", exc_info=err)
                continue

            # Full-waveform cross-correlation: a first correlation on the raw
            # traces feeds the relative pick times, then the preprocessed
            # traces are correlated to obtain the final shift and value.
            try:
                cc_shift_raw, _ = fres.analyze.correlate_waves(trace_1, trace_2, parameters['max_shift'])

                pick_p_mean_delay = fres.utils.relative_pick_time(trace_1.stats, trace_2.stats, pick_p_1, pick_p_2,
                                                                  cc_shift_raw)
                pick_s_mean_delay = fres.utils.relative_pick_time(trace_1.stats, trace_2.stats, pick_s_1, pick_s_2,
                                                                  cc_shift_raw)

                trace_1_filtered = fres.utils.cc_preprocess(trace_1, pick_p_mean_delay, pick_s_mean_delay,
                                                            freq_range, parameters['full_cc_waveform_window'])
                trace_2_filtered = fres.utils.cc_preprocess(trace_2, pick_p_mean_delay, pick_s_mean_delay,
                                                            freq_range, parameters['full_cc_waveform_window'])

                cc_shift, cc_value = fres.analyze.correlate_waves(trace_1_filtered,
                                                                  trace_2_filtered,
                                                                  parameters['max_shift'])
                # Move the window of the second trace by cc_shift sampling
                # intervals, zero-padding where data is missing.
                trace_2.trim(starttime=trace_2.stats.starttime - cc_shift * trace_2.stats.delta,
                             endtime=trace_2.stats.endtime - cc_shift * trace_2.stats.delta,
                             pad=True, fill_value=0.0)

            except fres.analyze.EmptyOrConstantTraceError as err:
                logging.warning(f"Empty or constant trace error for {network + '.' + station} "
                                f"while analyzing ({event_1.name}, {event_2.name}). Skipping.")
                logging.debug("Exception info.", exc_info=err)
                continue

            if cc_value < parameters['cross-correlation_threshold']:
                logging.debug(f"Cross-correlation {cc_value} lower than threshold for {network + '.' + station}")
                continue

            if cli_args.graphics_dir:
                # exist_ok avoids the race between testing for the directory and creating it.
                graphics_dir.mkdir(parents=True, exist_ok=True)

                starttime = parameters['full_cc_waveform_window'][0] - pick_p_mean_delay
                trace_2_filtered.trim(starttime=trace_2_filtered.stats.starttime - cc_shift * trace_2.stats.delta,
                                      endtime=trace_2_filtered.stats.endtime - cc_shift * trace_2.stats.delta,
                                      pad=True, fill_value=0.0)
                fres.plot.signals(trace_1_filtered, trace_2_filtered,
                                  starttime + pick_p_mean_delay - parameters['p_cs_waveform_window'][0],
                                  starttime + pick_p_mean_delay + parameters['p_cs_waveform_window'][1],
                                  starttime + pick_s_mean_delay - parameters['s_cs_waveform_window'][0],
                                  starttime + pick_s_mean_delay + parameters['s_cs_waveform_window'][1],
                                  f"M={event_1.magnitude:.2f}-{event_2.magnitude:.2f} "
                                  f"BP={freq_range[0]}-{freq_range[1]}Hz CC={cc_value:.2f}",
                                  graphics_dir / f"{event_1.name}_{event_2.name}_{network + '.' + station}_CCZ."
                                                 f"{cli_args.graphics_format}")

            # Cross-spectral delay of the P and S windows. Each window is
            # re-aligned by its own cross-correlation shift before the
            # multitaper cross-spectrum is computed.
            try:
                trace_1p = trace_1.copy()
                fres.utils.trim_window(trace_1p, pick_p_mean_delay, parameters['p_cs_waveform_window'])
                trace_2p = trace_2.copy()
                fres.utils.trim_window(trace_2p, pick_p_mean_delay, parameters['p_cs_waveform_window'])

                p_shift, _ = fres.analyze.correlate_waves(trace_1p, trace_2p, parameters['max_shift'])
                if p_shift != 0:
                    trace_2p.trim(starttime=trace_2p.stats.starttime - p_shift * trace_2p.stats.delta,
                                  endtime=trace_2p.stats.endtime - p_shift * trace_2p.stats.delta,
                                  pad=True, fill_value=0.0)
                cross_spectrum = MTCross(trace_1p.data,
                                         trace_2p.data,
                                         dt=trace_1p.stats.delta,
                                         nw=parameters['mt_coherence_time_bandwidth_product'],
                                         kspec=parameters['mt_coherence_num_taper'],
                                         nfft=int((trace_1p.stats.npts // 2) + 1),
                                         iadapt=1)
                slope_p, indices_p = fres.analyze.find_slope(cross_spectrum, freq_range, parameters)
                # Slope term from the cross-spectrum minus the two shifts
                # (window and full waveform) converted to time.
                time_delay_p = slope_p / (2 * np.pi) - trace_1p.stats.delta * (p_shift + cc_shift)
                if cli_args.graphics_dir:
                    fres.plot.cross_spectrum(trace_1p, trace_2p, cross_spectrum, slope_p, indices_p,
                                             f"delay={time_delay_p:0.6f} (subsample={slope_p / (2 * np.pi):0.6f})",
                                             graphics_dir / f"{event_1.name}_{event_2.name}_{network + '.' + station}"
                                                            f"_CSP.{cli_args.graphics_format}")

                trace_1s = trace_1.copy()
                fres.utils.trim_window(trace_1s, pick_s_mean_delay, parameters['s_cs_waveform_window'])
                trace_2s = trace_2.copy()
                fres.utils.trim_window(trace_2s, pick_s_mean_delay, parameters['s_cs_waveform_window'])

                s_shift, _ = fres.analyze.correlate_waves(trace_1s, trace_2s, parameters['max_shift'])
                if s_shift != 0:
                    trace_2s.trim(starttime=trace_2s.stats.starttime - s_shift * trace_2s.stats.delta,
                                  endtime=trace_2s.stats.endtime - s_shift * trace_2s.stats.delta,
                                  pad=True, fill_value=0.0)
                cross_spectrum = MTCross(trace_1s.data,
                                         trace_2s.data,
                                         dt=trace_1s.stats.delta,
                                         nw=parameters['mt_coherence_time_bandwidth_product'],
                                         kspec=parameters['mt_coherence_num_taper'],
                                         nfft=int((trace_1s.stats.npts // 2) + 1),
                                         iadapt=1)
                slope_s, indices_s = fres.analyze.find_slope(cross_spectrum, freq_range, parameters)
                time_delay_s = slope_s / (2 * np.pi) - trace_1s.stats.delta * (s_shift + cc_shift)
                if cli_args.graphics_dir:
                    fres.plot.cross_spectrum(trace_1s, trace_2s, cross_spectrum, slope_s, indices_s,
                                             f"delay={time_delay_s:0.6f} (subsample={slope_s / (2 * np.pi):0.6f})",
                                             graphics_dir / f"{event_1.name}_{event_2.name}_{network + '.' + station}"
                                                            f"_CSS.{cli_args.graphics_format}")

            except fres.analyze.LowCoherenceError as err:
                logging.debug(f"No coherence found for {network + '.' + station} "
                              f"while analyzing ({event_1.name}, {event_2.name}). Skipping.")
                logging.debug("Exception info.", exc_info=err)
                continue
            except fres.analyze.EmptyOrConstantTraceError as err:
                logging.warning(f"Empty or constant trace error for {network + '.' + station} "
                                f"while analyzing ({event_1.name}, {event_2.name}). Skipping.")
                logging.debug("Exception info.", exc_info=err)
                continue

            # Keep the station only when the S and P delays differ by less than the threshold.
            delta_sp = time_delay_s - time_delay_p
            if abs(delta_sp) < delta_sp_threshold:
                REs[pair_key][(network, station)] = (cc_value, delta_sp)
        # Discard pairs accepted at fewer than min_stations stations. The
        # lookup creates an empty entry for pairs with no accepted station,
        # and that entry is removed here as well.
        if len(REs[pair_key]) < min_stations:
            del REs[pair_key]
        elif cli_args.graphics_dir:
            fres.plot.similarity(event_1, event_2, REs, inventory, parameters['cross-correlation_threshold'],
                                 delta_sp_threshold,
                                 graphics_dir / f"{event_1.name}-{event_2.name}.{cli_args.graphics_format}")
    # Nothing is written when no pair qualified.
    if REs:
        # Group the accepted pairs into families.
        families = fres.utils.connected_components(REs.keys())
        family_names = ', '.join(str([catalogue.loc[n, 'name'] for n in family]) for family in families)
        logging.info(f"RE families: {family_names}")

        cli_args.output.parent.mkdir(parents=True, exist_ok=True)
        with cli_args.output.with_suffix('.yaml').open("w") as file:
            # Layout: family -> event pair -> station -> measurements, with
            # tuples rendered as strings for the first two levels.
            data = {}
            for n, family in enumerate(families):
                family_key = f"{(n, [catalogue.loc[i, 'name'] for i in family])}"
                family_pairs = {}
                for a, b in REs:
                    if a not in family and b not in family:
                        continue
                    names_key = f"{(catalogue.loc[a, 'name'], catalogue.loc[b, 'name'])}"
                    family_pairs[names_key] = {
                        f"{network + '.' + station}": {'cross-correlation': float(cc), 'delta_sp': float(dsp)}
                        for (network, station), (cc, dsp) in REs.get((a, b)).items()}
                data[family_key] = family_pairs
            dump(data, file, Dumper=Dumper)

        # Optionally write the HypoDD input files next to the YAML output.
        if cli_args.hypodd:
            errors = fres.read.errors(catalogue, cli_args.phase_file, cli_args.phase_type)
            fres.write.dump_hypodd(families, catalogue, errors, REs, parameters, cli_args.output.parent)
