#!/usr/bin/env python
# Licensed under a 3-clause BSD style license - see LICENSE.rst

from __future__ import (absolute_import, unicode_literals, division,
                        print_function)
import sys
from maltpynt.mp_base import mp_root
from maltpynt.mp_io import mp_get_file_type, is_string
import collections
import matplotlib.pyplot as plt
import numpy as np
import argparse


def normalize(array, ref=0):
    """Rescale the values of ``array`` that lie above ``ref``.

    Every element strictly greater than ``ref`` is mapped to
    ``(value - ref) / std(array)``; every other element is set to zero.
    In this module the result is used as scatter-plot marker sizes.

    Parameters
    ----------
    array : array-like
        Input values.
    ref : float, optional
        Reference level; only values above it get a non-zero output.

    Returns
    -------
    newarr : numpy.ndarray
        Array with the same shape as ``array``. Integer or boolean input is
        promoted to float so the rescaled values are not truncated. If the
        input is empty, or its standard deviation is zero, all outputs are
        zero instead of ``inf``/``nan``.
    """
    arr = np.asarray(array)
    # zeros_like on an integer array would produce an integer output and
    # silently truncate the rescaled values: promote to float first
    # (float and complex inputs keep their own dtype).
    if not np.issubdtype(arr.dtype, np.inexact):
        arr = arr.astype(float)

    newarr = np.zeros_like(arr)
    if arr.size == 0:
        # np.std of an empty array is nan and emits a RuntimeWarning.
        return newarr

    std = np.std(arr)
    if not std > 0:
        # No spread to normalize by; avoid dividing by zero (-> inf).
        return newarr

    good = arr > ref
    newarr[good] = (arr[good] - ref) / std
    return newarr


def usage():
    """Print a short command-line usage message to standard output."""
    text = '''Usage:

          MPdumpdyn filename

          where filename is any file in a valid MaLTPyNT format for PDS or CPDS
          '''
    print(text)


def dumpdyn(fname, plot=False):
    """Dump a dynamical (cross) power spectrum to a text file.

    The output file is ``<root>_dumped_<ftype>.txt``, where ``root`` comes
    from ``mp_root(fname)``. It has one row per (time, frequency) bin and
    four columns: time, frequency, power and power error. Only the real
    parts are written.

    Parameters
    ----------
    fname : str
        Any file in a valid MaLTPyNT format for a PDS or CPDS.
    plot : bool, optional
        If True, also show a scatter plot of the bins in the time-frequency
        plane. Marker sizes are given by ``normalize`` applied to the power.
    """
    ftype, pdsdata = mp_get_file_type(fname, specify_reb=False)

    dynpds = pdsdata['dyn' + ftype]
    edynpds = pdsdata['edyn' + ftype]

    try:
        freq = pdsdata['freq']
    except KeyError:
        # No 'freq' entry: presumably only bin edges are stored, so use
        # the bin centers.
        flo = pdsdata['flo']
        fhi = pdsdata['fhi']
        freq = (fhi + flo) / 2

    time = pdsdata['dyntimes']

    # Build (time, frequency) grids with the same shape and dtype as the
    # spectrum: every row shares the frequency axis, and every element of
    # row i holds time[i].
    freqs = np.zeros_like(dynpds)
    freqs[:] = freq
    times = np.zeros_like(dynpds)
    times[:] = np.asarray(time)[:len(dynpds), np.newaxis]

    t = times.real.flatten()
    f = freqs.real.flatten()
    d = dynpds.real.flatten()
    e = edynpds.real.flatten()

    np.savetxt('{}_dumped_{}.txt'.format(mp_root(fname), ftype),
               np.array([t, f, d, e]).T)

    if plot:
        # Marker sizes are only needed for the plot.
        plt.scatter(t, f, s=normalize(d))
        plt.xlabel('Time (s)')
        plt.ylabel('Freq (Hz)')

        plt.show()

# Command-line interface: one or more (c)PDS files, plus an optional
# --noplot switch.
description = 'Dump dynamical (cross) power spectra'
parser = argparse.ArgumentParser(description=description)

parser.add_argument("files", help="List of (c)PDS files", nargs='+')
# Plotting is on by default; --noplot turns it off, so the help text must
# describe the flag's effect rather than the default behaviour.
parser.add_argument("--noplot", help="Do not plot results",
                    default=False, action='store_true')

if __name__ == '__main__':
    # Parse the command line, then dump every requested file in turn.
    cli_args = parser.parse_args()
    input_files = cli_args.files
    show_plot = not cli_args.noplot

    # With nargs='+' argparse already rejects an empty file list; this is
    # kept only as a safety net.
    if not input_files:
        usage()
        sys.exit()

    for input_file in input_files:
        dumpdyn(input_file, plot=show_plot)
