#!python

import argparse
import glob
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import sys

from astrotf.radio import FilterEngine

# ---------------------------------------------------------------------------------
# Command line arguments
# ---------------------------------------------------------------------------------

parser = argparse.ArgumentParser(
    description='Process Amber trigger files.',
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)


def _add_value_option(short_flag, long_flag, value_type, description, fallback=None):
    """Register an option that takes at most one value.

    Every valued option of this script shares the same shape: ``nargs='?'``
    with ``const`` equal to ``default``, so a flag given without a value
    yields the same result as an omitted flag.

    :param short_flag: short option string, e.g. ``"-i"``
    :param long_flag: long option string, e.g. ``"--input"``
    :param value_type: callable converting the command line string
    :param description: help text shown by ``--help``
    :param fallback: value used for both ``const`` and ``default``
    """
    parser.add_argument(
        short_flag,
        long_flag,
        nargs='?',
        type=value_type,
        const=fallback,
        default=fallback,
        help=description,
    )


# Input / output files
_add_value_option("-i", "--input", str, "input filename(s). Wildcards are allowed.", "CB??.trigger")
_add_value_option("-r", "--ref", str, "Reference trigger file for plotting")
_add_value_option("-o", "--output", str, "output filename. Writes to stdout when omitted.")
_add_value_option("-p", "--plot", str, "Optional plot filename.")

# Plot time range
_add_value_option("-pt0", "--plot_t0", float, "Start plot time.")
_add_value_option("-pt1", "--plot_t1", float, "End plot time.")

# Output DM filter
_add_value_option("-fm0", "--filter_dm0", float, "lower dm bound, values below will be ignored.", 20.0)

# Plot DM range and appearance
_add_value_option("-pdm0", "--plot_dm0", float, "lower dm bound in plot.")
_add_value_option("-pdm1", "--plot_dm1", float, "Upper dm bound in plot.")
parser.add_argument("-pdml", "--plot_dmlog", action="store_true", help="Use logarithmic dm axis.")
_add_value_option("-ps", "--plot_size", float, "Plot marker size factor.", 1.0)

# Observation parameters
_add_value_option("-fh", "--freq_hi_mhz", float, "Highest observation frequency (MHz)", 1549.8)
_add_value_option("-fl", "--freq_lo_mhz", float, "Lowest observation frequency (MHz)", 1249.8)
_add_value_option("-dt", "--sample_time", float, "Sampling time (seconds)", 8.192e-05)

# Output format and verbosity
_add_value_option("-of", "--output_format", int, "output format. 0: Amber Classifier", 0)
parser.add_argument("-v", "--verbose", action="store_true", help="modify output verbosity")

args = parser.parse_args()


def read_trigger_file(filename, verbose=True):
    """Read a single trigger file into a pandas DataFrame.

    The first line of the file is taken as the header. Column names are
    derived from it by lower-casing, removing a ``" (s)"`` unit suffix and
    dropping a leading ``#`` comment marker. All remaining lines are parsed
    as whitespace separated values; lines starting with ``#`` are skipped.

    The columns ``beam``, ``batch``, ``sample`` and ``sigma`` are renamed to
    ``beam_id``, ``batch_id``, ``sample_id`` and ``snr``.

    :param filename: path of the trigger file to read
    :param verbose: when true, print progress messages to stdout
    :return: DataFrame with one row per data line; the index labels are
        strings (kept for backward compatibility)
    """
    if verbose:
        print('start reading:', filename)

    # First read the header line, and extract the column names
    with open(filename, "r") as f:
        header = f.readline()

    # Correct column names that contain a space, e.g. "time (s)" -> "time"
    header = header.lower().replace(" (s)", "").strip()

    # Remove the leading '#' marker, whether or not a space follows it
    if header.startswith('#'):
        header = header[1:]

    # Splitting without a separator also collapses multiple spaces
    colnames = header.split()

    # Read the data. sep=r'\s+' is the documented equivalent of the
    # deprecated delim_whitespace=True.
    triggers = pd.read_csv(
        filename,
        sep=r'\s+',
        names=colnames,
        skiprows=1,
        header=None,
        comment='#'
    )

    # Fix the column names some more; index=str keeps the string index
    # labels the previous implementation produced.
    triggers = triggers.rename(index=str, columns={
        "beam": "beam_id",
        "batch": "batch_id",
        "sample": "sample_id",
        "sigma": "snr"
    })

    if verbose:
        print('finished reading:', filename)

    return triggers


# -----------------------------------------------------------------------------------------------------
# Read the input files
# -----------------------------------------------------------------------------------------------------
if args.verbose:
    print('Looking for input files:', args.input)

# Read all trigger files matching the input pattern. The matches are sorted
# so that the merge order does not depend on the directory listing order.
input_chunk = [
    read_trigger_file(f, args.verbose) for f in sorted(glob.glob(args.input))
]

if not input_chunk:
    # sys.exit with a message prints it to stderr and exits with status 1;
    # this keeps the error out of stdout, which may carry trigger output.
    sys.exit('No input files found.')

# merge the files into a big one
data = pd.concat(input_chunk, axis=0, ignore_index=True)

if args.verbose:
    print('Finished reading and merging all trigger files. Read {} triggers'.format(data.shape[0]))


# -----------------------------------------------------------------------------------------------------
# enrich the data
# -----------------------------------------------------------------------------------------------------
# 'w' is the integration step multiplied by the sampling time given on the
# command line.
data['w'] = data.integration_step * args.sample_time


# -----------------------------------------------------------------------------------------------------
# Filter triggers: first process triggers of the same width
# -----------------------------------------------------------------------------------------------------
# Column layout of the tuples passed to, and collected from, the filter engine.
trigger_fields = ['time', 'w', 'dm', 'snr', 'beam_id', 'sample_id', 'integration_step']

per_width_results = []
for width in data.w.unique():

    # Work on a private copy of the triggers that share this width.
    same_width = data.loc[data['w'] == width].copy()
    if args.verbose:
        print('Processing: {} triggers of width {}'.format(same_width.shape[0], width))

    # A fresh engine is created for every width.
    eng = FilterEngine(args.freq_lo_mhz, args.freq_hi_mhz, buffer_size=512, nn_size=32, tol=1E-3)

    eng.sort(same_width, ['time', 'w', 'dm'])

    # Lazily feed the triggers to the engine as plain tuples.
    candidates = (
        (row.time, row.w, row.dm, row.snr, row.beam_id, row.sample_id, row.integration_step)
        for row in same_width.itertuples()
    )
    kept = list(eng.filter(candidates))

    per_width_results.append(pd.DataFrame(kept, columns=trigger_fields))
    if args.verbose:
        print("Finished processing:", eng.num_in, "->", eng.num_out)

# Combine the per-width results into a single frame for the second round.
output = pd.concat(per_width_results, axis=0, ignore_index=True)

if args.verbose:
    print("Finished processing all widths. First round reduction to {} triggers".format(output.shape[0]))


# -----------------------------------------------------------------------------------------------------
# Merge triggers of separate widths and filter once again
# -----------------------------------------------------------------------------------------------------
if args.verbose:
    print('Processing merged triggers...')

# Second pass over the merged first-round output; this engine uses
# nn_size=64 where the per-width pass used 32.
eng = FilterEngine(args.freq_lo_mhz, args.freq_hi_mhz, buffer_size=512, nn_size=64, tol=1E-3)

eng.sort(output, ['time', 'w', 'dm'])

# Lazily feed the first-round triggers to the engine as plain tuples.
merged_candidates = (
    (trig.time, trig.w, trig.dm, trig.snr, trig.beam_id, trig.sample_id, trig.integration_step)
    for trig in output.itertuples()
)

output_round2 = pd.DataFrame(
    list(eng.filter(merged_candidates)),
    columns=['time', 'w', 'dm', 'snr', 'beam_id', 'sample_id', 'integration_step']
)

if args.verbose:
    print('Finished processing:', eng.num_in, "->", eng.num_out)

# -----------------------------------------------------------------------------------------------------
# Filter output
# -----------------------------------------------------------------------------------------------------
# Compare against None rather than relying on truthiness, so that an explicit
# bound of 0 is treated like any other bound.
if args.filter_dm0 is not None:
    # lower dm filter
    output_round2 = output_round2.loc[output_round2.dm >= args.filter_dm0]
# -----------------------------------------------------------------------------------------------------
# Write output results
# -----------------------------------------------------------------------------------------------------
if args.output:

    if args.verbose:
        print("Writing results to", args.output)

    if args.output_format == 0:
        # The five fields are concatenated and reshaped to (5, n), so the
        # file holds one line per field: snr, dm, time, integration_step,
        # sample_id.
        # NOTE(review): confirm the consumer expects one row per field
        # rather than one row per trigger.
        np.savetxt(
            args.output,
            np.concatenate([
                output_round2.snr.values,
                output_round2.dm.values,
                output_round2.time.values,
                output_round2.integration_step.values,
                output_round2.sample_id.values
            ]).reshape(5, -1)
        )
    else:
        print("Unknown output format option")
        sys.exit(1)
else:

    # Without an output file the final (second round, dm filtered) triggers
    # go to stdout, matching what is written to file and shown in the plot.
    for e in output_round2.itertuples():
        print(e.time, e.w, e.dm, e.snr, e.beam_id, e.sample_id, e.integration_step)


# -----------------------------------------------------------------------------------------------------
# make a plot
# -----------------------------------------------------------------------------------------------------
if args.plot:

    if args.verbose:
        print("Generating plots.")

    def _scatter_triggers(ax, triggers, label):
        """Scatter time versus dm for *triggers* on *ax*.

        The marker area scales with the square of the snr, which is clipped
        at 30, times the plot size factor from the command line. The title
        is *label* followed by the number of triggers.
        """
        ax.scatter(
            triggers.time,
            triggers.dm,
            s=args.plot_size * 0.2 * np.minimum(triggers.snr, 30) ** 2,
            lw=1,
            edgecolor='k',
            c='r'
        )
        ax.set_title('{} {}'.format(label, triggers.shape[0]), fontsize=16)

    # Input and output panels are always drawn; the reference panel only
    # when a reference file was given.
    panels = [('input', data), ('output', output_round2)]
    if args.ref:
        panels.append(('reference', read_trigger_file(args.ref, args.verbose)))

    fig, axes = plt.subplots(
        len(panels), 1,
        figsize=(16, 14 if args.ref else 10),
        sharex=True,
        sharey=True
    )

    for ax, (label, triggers) in zip(axes, panels):
        _scatter_triggers(ax, triggers, label)

    for ax in axes:
        ax.grid()
        ax.set_ylabel('DM', fontsize=16)
        ax.set_xlabel('time (s)', fontsize=16)

    if args.plot_dmlog:
        # Set the scale on every panel, including the reference panel.
        for ax in axes:
            ax.set_yscale('log')

    # The panels share both axes, so limits set on one panel apply to all.
    # Compare against None so that an explicit limit of 0 is honoured.
    if args.plot_t0 is not None:
        axes[0].set_xlim(left=args.plot_t0)
    if args.plot_t1 is not None:
        axes[0].set_xlim(right=args.plot_t1)
    if args.plot_dm0 is not None:
        axes[0].set_ylim(bottom=args.plot_dm0)
    if args.plot_dm1 is not None:
        axes[0].set_ylim(top=args.plot_dm1)

    plt.tight_layout()
    plt.savefig(args.plot)

if args.verbose:
    print("done.")
