#!python

import argparse
import glob
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from astrotf.radio import FilterEngine

# ---------------------------------------------------------------------------------
# Command line arguments
# ---------------------------------------------------------------------------------

parser = argparse.ArgumentParser(
    description='Process Amber trigger files.',
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)


def _add_value_option(short_flag, long_flag, value_type, fallback, description):
    """Register an option on ``parser`` whose value may be omitted.

    ``fallback`` is used both when the option is absent (``default``) and
    when the flag is given without a value (``const``).
    """
    parser.add_argument(
        short_flag,
        long_flag,
        nargs='?',
        type=value_type,
        const=fallback,
        default=fallback,
        help=description,
    )


# Input / output locations
_add_value_option(
    "-i", "--input", str, "CB??.trigger",
    "input filename(s). Wildcards are allowed.",
)
_add_value_option(
    "-o", "--output", str, None,
    "output filename. Writes to stdout when omitted.",
)
_add_value_option("-p", "--plot", str, None, "Optional plot filename.")

# Observation parameters
_add_value_option(
    "-fh", "--freq_hi_mhz", float, 1549.8,
    "Highest observation frequency (MHz)",
)
_add_value_option(
    "-fl", "--freq_lo_mhz", float, 1249.8,
    "Lowest observation frequency (MHz)",
)
_add_value_option(
    "-dt", "--sample_time", float, 8.192e-05,
    "Sampling time (seconds)",
)

# Output formatting and verbosity
_add_value_option(
    "-of", "--output_format", int, 0,
    "output format. 0: Amber Classifier",
)
parser.add_argument(
    "-v",
    "--verbose",
    action="store_true",
    help="modify output verbosity",
)

# Parsed once at start-up; the rest of the script reads its settings from here.
args = parser.parse_args()


def read_trigger_file(filename, verbose=True):
    """Read one Amber trigger file into a DataFrame.

    The first line of the file is used as the header: it is lower-cased, a
    " (s)" unit suffix is dropped, and a leading '#' marker is removed before
    the remaining whitespace-separated words become the column names. All
    following lines are parsed as whitespace-separated values; anything after
    a '#' on a line is ignored.

    Columns named "beam", "batch", "sample" and "sigma" are renamed to
    "beam_id", "batch_id", "sample_id" and "snr".

    Args:
        filename: path of the trigger file to read.
        verbose: when true, print progress messages to stdout.

    Returns:
        pandas.DataFrame with one row per trigger. The index labels are
        converted to strings by the final rename step.
    """
    if verbose:
        print('start reading:', filename)

    # First read the header line, and extract the column names
    with open(filename, "r") as f:
        header = f.readline()

    # correct column names that contain a space before the unit, e.g. "time (s)"
    header = header.lower().replace(" (s)", "")

    # remove the leading '#' marker, whether or not it is separated from the
    # first column name by a space; split() then collapses multiple spaces
    colnames = header.strip().lstrip('#').split()

    # Read the data. sep=r"\s+" is the documented equivalent of the
    # deprecated delim_whitespace=True.
    triggers = pd.read_csv(
        filename,
        sep=r"\s+",
        names=colnames,
        skiprows=1,
        header=None,
        comment='#'
    )

    # fix column names some more (index=str also turns the row labels into strings)
    triggers.rename(index=str, inplace=True, columns={
        "beam": "beam_id",
        "batch": "batch_id",
        "sample": "sample_id",
        "sigma": "snr"
    })

    if verbose:
        print('finished reading:', filename)

    return triggers


# -----------------------------------------------------------------------------------------------------
# Read the input files
# -----------------------------------------------------------------------------------------------------
if args.verbose:
    print('Looking for input files:', args.input)

# Read every trigger file matching the input pattern. glob returns matches in
# arbitrary order, so they are sorted to make the merged row order reproducible.
input_chunk = [
    read_trigger_file(path, args.verbose)
    for path in sorted(glob.glob(args.input))
]

if not input_chunk:
    print('No input files found.')
    # Raise SystemExit directly: the exit() helper is injected by the site
    # module for interactive use and is not guaranteed to exist in programs.
    raise SystemExit(1)

# merge the files into a big one, renumbering the rows 0..n-1
data = pd.concat(input_chunk, axis=0, ignore_index=True)

if args.verbose:
    print('Finished reading and merging all trigger files. Read {} triggers'.format(data.shape[0]))


# -----------------------------------------------------------------------------------------------------
# enrich the data
# -----------------------------------------------------------------------------------------------------
# Add column 'w': the integration step multiplied by the sampling time
data['w'] = args.sample_time * data['integration_step']


# -----------------------------------------------------------------------------------------------------
# Filter triggers: first process triggers of the same width
# -----------------------------------------------------------------------------------------------------
# Run a separate filter pass over each distinct trigger width and collect
# the surviving triggers of every pass.
output_chunks = []
for width in data['w'].unique():

    # Work on a private copy of the rows that share this width
    same_width = data[data['w'] == width].copy()
    if args.verbose:
        print('Processing: {} triggers of width {}'.format(same_width.shape[0], width))

    # A fresh engine per width, so num_in / num_out describe this pass only
    engine = FilterEngine(
        args.freq_lo_mhz,
        args.freq_hi_mhz,
        buffer_size=512,
        nn_size=32,
        tol=1E-4,
    )

    # NOTE(review): presumably orders same_width in place (the return value
    # is not used) -- confirm against the FilterEngine implementation.
    engine.sort(same_width, ['time', 'w', 'dm'])

    # Feed the engine plain tuples, one per trigger
    candidates = (
        (row.time, row.w, row.dm, row.snr, row.beam_id, row.sample_id, row.integration_step)
        for row in same_width.itertuples()
    )
    survivors = list(engine.filter(candidates))

    output_chunks.append(
        pd.DataFrame(
            survivors,
            columns=['time', 'w', 'dm', 'snr', 'beam_id', 'sample_id', 'integration_step'],
        )
    )
    if args.verbose:
        print("Finished processing:", engine.num_in, "->", engine.num_out)

# Stack the per-width results into a single frame with a fresh 0..n-1 index
output = pd.concat(output_chunks, axis=0, ignore_index=True)

if args.verbose:
    print("Finished processing all widths. First round reduction to {} triggers".format(output.shape[0]))


# -----------------------------------------------------------------------------------------------------
# Merge triggers of separate widths and filter once again
# -----------------------------------------------------------------------------------------------------
if args.verbose:
    print('Processing merged triggers...')

# Second pass: one engine over the merged first-round result, this time
# with nn_size=16.
eng = FilterEngine(
    args.freq_lo_mhz,
    args.freq_hi_mhz,
    buffer_size=512,
    nn_size=16,
    tol=1E-4,
)

# NOTE(review): presumably orders `output` in place (the return value is not
# used) -- confirm against the FilterEngine implementation.
eng.sort(output, ['time', 'w', 'dm'])

# Feed the engine plain tuples, one per first-round trigger
merged_candidates = (
    (row.time, row.w, row.dm, row.snr, row.beam_id, row.sample_id, row.integration_step)
    for row in output.itertuples()
)
output_round2 = pd.DataFrame(
    list(eng.filter(merged_candidates)),
    columns=['time', 'w', 'dm', 'snr', 'beam_id', 'sample_id', 'integration_step'],
)

if args.verbose:
    print('Finished processing:', eng.num_in, "->", eng.num_out)


# -----------------------------------------------------------------------------------------------------
# Write output results
# -----------------------------------------------------------------------------------------------------
if args.output:

    if args.verbose:
        print("Writing results to", args.output)

    if args.output_format == 0:
        # Stack the five quantities into a (5, n) array: one text row per
        # quantity, in the order snr, dm, time, integration_step, sample_id.
        # NOTE(review): this writes 5 lines of n values each; confirm that
        # the consumer expects this layout rather than one line per trigger.
        np.savetxt(
            args.output,
            np.vstack([
                output_round2.snr.values,
                output_round2.dm.values,
                output_round2.time.values,
                output_round2.integration_step.values,
                output_round2.sample_id.values
            ])
        )
    else:
        print("Unknown output format option")
        # Raise SystemExit directly: the exit() helper is injected by the
        # site module for interactive use and should not be used in programs.
        raise SystemExit(1)
else:

    # No output file given: print the final (second-round) triggers, the same
    # result set that the file branch writes.
    for e in output_round2.itertuples():
        print(e.time, e.w, e.dm, e.snr, e.beam_id, e.sample_id, e.integration_step)


# -----------------------------------------------------------------------------------------------------
# make a plot
# -----------------------------------------------------------------------------------------------------
if args.plot:
    if args.verbose:
        print("Generating plots.")

    # Two stacked panels with shared axes: all input triggers on top, the
    # triggers that survived both filter rounds below.
    fig, panels = plt.subplots(2, 1, figsize=(22, 14), sharex=True, sharey=True)
    ax1, ax2 = panels

    marker_style = dict(s=20, lw=0, c='k')
    ax1.scatter(data.time, data.dm, **marker_style)
    ax2.scatter(output_round2.time, output_round2.dm, **marker_style)

    ax1.set_ylabel('DM')
    ax1.set_xlabel('time (s)')

    # Panel titles carry the number of triggers shown
    ax1.set_title('input {}'.format(data.shape[0]))
    ax2.set_title('output {}'.format(output_round2.shape[0]))

    # Logarithmic DM axis on both panels, then a fixed DM range on both
    for panel in (ax1, ax2):
        panel.set_yscale('log')
    for panel in (ax1, ax2):
        panel.set_ylim(0.1, 4000)

    plt.tight_layout()
    plt.savefig(args.plot)

if args.verbose:
    print("done.")
