#!python

import argparse
import matplotlib.pyplot as plt
import pandas as pd
import glob
from astrotf.radio import FilterEngine


# ---------------------------------------------------------------------------------
# Command line arguments
# ---------------------------------------------------------------------------------

parser = argparse.ArgumentParser(description='Process Amber trigger files.')


def _add_value_option(short_flag, long_flag, value_type, help_text, fallback=None):
    """Register an option that accepts at most one value (``nargs='?'``).

    *fallback* is used both when the flag is absent (``default``) and when
    the flag is given without a value (``const``).  Leaving it as ``None``
    means the option is simply unset in either case.
    """
    parser.add_argument(
        short_flag,
        long_flag,
        nargs='?',
        type=value_type,
        const=fallback,
        default=fallback,
        help=help_text,
    )


# File selection and destinations.
_add_value_option(
    "-i", "--input", str,
    "input filename(s). Wildcards are allowed.",
    fallback="CB??.trigger",
)
_add_value_option(
    "-o", "--output", str,
    "output filename. Writes to stdout when omitted.",
)
_add_value_option(
    "-p", "--plot", str,
    "Optional plot filename.",
)

# Observation parameters.
_add_value_option(
    "-fh", "--freq_hi_mhz", float,
    "Highest observation frequency (MHz)",
    fallback=1549.8,
)
_add_value_option(
    "-fl", "--freq_lo_mhz", float,
    "Lowest observation frequency (MHz)",
    fallback=1249.8,
)
_add_value_option(
    "-dt", "--sample_time", float,
    "Sampling time (seconds)",
    fallback=8.192e-05,
)

# Plain on/off switch.
parser.add_argument(
    "-v",
    "--verbose",
    action="store_true",
    help="modify output verbosity",
)


args = parser.parse_args()

def read_trigger_file(filename: str, verbose: bool = True) -> pd.DataFrame:
    """Read one trigger file into a DataFrame.

    The first line of the file is used as the header.  It is lower-cased,
    the " (s)" unit suffix is removed, and any leading '#' (with or without
    a following space) is dropped; the remaining whitespace-separated words
    become the column names.  The rest of the file is parsed as
    whitespace-separated values, with everything after a '#' on a line
    ignored as a comment.

    Args:
        filename: Path of the trigger file to read.
        verbose: When true, print a message before and after reading.

    Returns:
        The parsed triggers.  The columns "beam", "batch", "sample" and
        "sigma" are renamed to "beam_id", "batch_id", "sample_id" and "snr";
        the row labels are converted to strings.
    """
    if verbose:
        print('start reading:', filename)

    # First read the header line, and extract column names
    with open(filename, "r") as f:
        header = f.readline()

    # Correct column names that contain spaces (e.g. "time (s)" -> "time").
    header = header.lower().replace(" (s)", "")

    # Drop the leading comment marker.  Stripping it from the line itself
    # (rather than popping a separate '#' token) also handles "#beam ...",
    # where the marker is glued to the first column name.
    # split() without arguments collapses runs of whitespace.
    colnames = header.strip().lstrip('#').split()

    # Read the data.  sep=r'\s+' replaces the deprecated delim_whitespace=True.
    triggers = pd.read_csv(
        filename,
        sep=r'\s+',
        names=colnames,
        skiprows=1,
        header=None,
        comment='#'
    )

    # Fix column names some more.  index=str is kept from the original so
    # the returned row labels stay strings.
    triggers = triggers.rename(index=str, columns={
        "beam": "beam_id",
        "batch": "batch_id",
        "sample": "sample_id",
        "sigma": "snr"
    })

    if verbose:
        print('finished reading:', filename)

    return triggers


# --------------------------------------------------------------------
# Main
# --------------------------------------------------------------------

if args.verbose:
    print('Looking for input files:', args.input)

# Read every trigger file matching the (possibly wildcarded) input pattern.
input_chunk = [
    read_trigger_file(trigger_file, args.verbose)
    for trigger_file in glob.glob(args.input)
]

if not input_chunk:
    print('No input files found.')
    # The bare `exit` helper is injected by the `site` module for interactive
    # use and is not guaranteed to exist (e.g. under `python -S`); raising
    # SystemExit terminates with the same status code everywhere.
    raise SystemExit(1)

# Merge the per-file frames into one, renumbering the rows.
data = pd.concat(input_chunk, axis=0, ignore_index=True)

if args.verbose:
    print('Finished reading and merging all trigger files. Read {} triggers'.format(data.shape[0]))

# Enrich the data: width 'w' is the integration step times the sampling time.
data['w'] = data.integration_step * args.sample_time


# Columns handed to the filter engine and used to label its output,
# in this order.
_TRIGGER_COLUMNS = [
    'time',
    'w',
    'dm',
    'snr',
    'beam_id',
    'sample_id',
    'integration_step'
]


def _filter_triggers(triggers, nn_size):
    """Run one FilterEngine pass over *triggers*.

    A fresh engine is built for every pass from the command-line frequency
    range; only *nn_size* differs between the passes in this script.  The
    engine's ``sort`` is called on *triggers* (by time, w, dm) before the
    rows are fed to ``filter`` as plain tuples in ``_TRIGGER_COLUMNS`` order.

    Returns:
        A ``(filtered, engine)`` pair: the DataFrame built from whatever the
        engine yields, and the engine itself so the caller can report its
        ``num_in`` / ``num_out`` counters.
    """
    engine = FilterEngine(
        args.freq_lo_mhz,
        args.freq_hi_mhz,
        buffer_size=512,
        nn_size=nn_size,
        tol=1E-4
    )

    engine.sort(triggers, ['time', 'w', 'dm'])

    # index=False / name=None yields bare tuples holding just the selected
    # columns, i.e. (time, w, dm, snr, beam_id, sample_id, integration_step).
    rows = triggers[_TRIGGER_COLUMNS].itertuples(index=False, name=None)
    filtered = pd.DataFrame(list(engine.filter(rows)), columns=_TRIGGER_COLUMNS)

    return filtered, engine


# Round 1: process the triggers of each individual width separately.
output_chunks = []
for w in data.w.unique():

    # Work on a copy so the engine's sort cannot touch the merged input.
    dataw = data.loc[data['w'] == w].copy()
    if args.verbose:
        print('Processing: {} triggers of width {}'.format(dataw.shape[0], w))

    output_df, eng = _filter_triggers(dataw, nn_size=32)

    output_chunks.append(output_df)
    if args.verbose:
        print("Finished processing:", eng.num_in, "->", eng.num_out)


if output_chunks:
    output = pd.concat(output_chunks, axis=0, ignore_index=True)
else:
    # pd.concat raises ValueError on an empty list (input files without any
    # trigger rows); continue with an empty frame of the expected layout.
    output = pd.DataFrame(columns=_TRIGGER_COLUMNS)

if args.verbose:
    print("Finished processing all widths. First round reduction to {} triggers".format(output.shape[0]))

if args.verbose:
    print('Processing merged triggers...')

# Round 2: filter the merged first-round survivors across all widths.
output_round2, eng = _filter_triggers(output, nn_size=16)

if args.verbose:
    print('Finished processing:', eng.num_in, "->", eng.num_out)

# Emit the final (second-round) result: to a space-separated file when an
# output filename was given, otherwise to stdout.
if args.output:
    if args.verbose:
        print("Writing results to", args.output)
    output_round2.to_csv(args.output, sep=" ", index=False)
else:
    # Print the same frame the file branch writes.  The original iterated the
    # first-round frame `output` here, so stdout and file results disagreed.
    # Rows are printed space-separated, without a header line.
    for e in output_round2.itertuples():
        print(
            e.time,
            e.w,
            e.dm,
            e.snr,
            e.beam_id,
            e.sample_id,
            e.integration_step
        )

# Optional diagnostic plot: DM against time for the merged input (top panel)
# and for the final filtered triggers (bottom panel).
if args.plot:
    if args.verbose:
        print("Generating plots.")

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(22,14), sharex=True, sharey=True)

    # Both panels share the same marker styling.
    marker_style = dict(s=20, lw=0, c='k', alpha=1)
    panels = (
        (ax1, data, 'input {}'),
        (ax2, output_round2, 'output {}'),
    )

    for axis, frame, _ in panels:
        axis.scatter(frame.time, frame.dm, **marker_style)

    ax1.set_ylabel('DM')
    ax1.set_xlabel('time (s)')

    # Each title carries the number of triggers shown in that panel.
    for axis, frame, title_template in panels:
        axis.set_title(title_template.format(frame.shape[0]))

    for axis in (ax1, ax2):
        axis.set_yscale('log')

    for axis in (ax1, ax2):
        axis.set_ylim(0.1, 4000)

    plt.tight_layout()
    plt.savefig(args.plot)

if args.verbose:
    print("done.")
