#!/usr/bin/env python

from __future__ import print_function
from argparse import ArgumentParser

def load_data(filename):
    """Load a MESA/GYRE style data file as two structured arrays.

    Counting lines from one, line 2 is read as the names of the scalar
    quantities and line 3 as their values; line 6 is read as the
    column names of the table and every following line as one row of
    it.

    Parameters
    ----------
    filename: str
        Filename to load.

    Returns
    -------
    header: structured array
        Scalar data, i.e. quantities with a single value per file
        (e.g. initial parameters for a MESA history, the total mass
        in a MESA profile or GYRE summary, the mode frequency of a
        GYRE mode file).

    data: structured array
        Vector data, i.e. quantities with many entries per file
        (e.g. luminosity over an evolutionary run in a MESA history,
        density as a function of radius in a MESA profile, mode
        frequencies in a GYRE summary).

    """
    with open(filename, 'r') as stream:
        raw_lines = stream.readlines()

    # The lines are handed to `genfromtxt` as UTF-8 encoded bytes.
    encoded = [text.encode('utf-8') for text in raw_lines]

    scalar_block = encoded[1:3]   # names line + single row of values
    table_block = encoded[5:]     # names line + all rows of the table

    header = np.genfromtxt(scalar_block, names=True)
    data = np.genfromtxt(table_block, names=True)

    return header, data

def get_column(data, key, filename=None):
    """Gets data in column `key` from the structured array `data`.

    If the desired key isn't available, the function will see if the
    logarithm (base 10) version is available as `log_<key>` and
    return ten to the power of that column.  Similarly, if the
    requested key starts with `log_`, the function will see if the
    non-logarithmic version is available and return its base-10
    logarithm.

    Parameters
    ----------
    data: structured array
        Array containing data from file, as returned by `load_data`.
    key: str
        Name of column for which to retrieve data.
    filename: str, optional
        Name of the file `data` came from.  Only used to make the
        error message more informative; left out of the message when
        not given.

    Returns
    -------
    column: numpy.ndarray
        Array containing data for the desired column.

    Raises
    ------
    KeyError
        If neither `key` nor its logarithmic/non-logarithmic
        counterpart is a column of `data`.  The message lists the
        available columns.

    """
    names = data.dtype.names

    if key in names:
        return data[key]

    # Only the logarithm is stored: undo it.
    log_key = 'log_' + key
    if log_key in names:
        return 10.**data[log_key]

    # The logarithm was requested but only the plain quantity is stored.
    if key.startswith('log_') and key[4:] in names:
        return np.log10(data[key[4:]])

    # Report the key that was actually requested, using only values
    # passed to this function (no reliance on the caller's globals).
    where = (" in file '%s'" % filename) if filename is not None else ""
    raise KeyError("'%s' is not a column%s. " % (key, where) +
                   "Available columns are: " +
                   ", ".join(names))
        

# Command-line interface.  Options declared with nargs='+' arrive as
# lists of words; labels and the title are joined with spaces by the
# code that consumes them.
parser = ArgumentParser( description="""``qpmg`` is a simple Python
script to quickly inspect output that adheres to the format used by
MESA's profiles and histories and GYRE's summaries and mode files.
While ``qpmg`` provides many options, it's intended for quick
inspection rather than publication-quality plots.

To see the list of available columns in a file, run ``qpmg`` on a given
file.  The defaults will cause an error that displays the available
columns.""")
parser.add_argument('filenames', type=str, nargs='+')
# The default key '' matches no column, so running without -x/-y ends in
# the "available columns" error mentioned in the description.
parser.add_argument('-x', type=str, nargs='+', default=[''],
                    help="Column(s) to use for the x variable (see -y).")
parser.add_argument('-y', type=str, nargs='+', default=[''],
                    help="Column(s) to use for the x and y variables.  "
                    "The code loops through "
                    "however many x and y keys you give (inner loop "
                    "over y, outer loop over x) but most of the time "
                    "you probably only want one x variable.")
parser.add_argument('--xlabel', type=str, nargs='+', default=None,
                    help="Overrides the x-axis label (see --ylabel).")
parser.add_argument('--ylabel', type=str, nargs='+', default=None,
                    help="Overrides the axis label with the given string.  "
                    "Accepts spaces. i.e. 'effective temperature' is OK.  "
                    "Default is to use the first argument of -x/-y.")
parser.add_argument('--legend', type=str, nargs='+', default=None,
                    help="If 'auto', add a legend using the filenames as "
                    "keys.  Otherwise, use the arguments as a list of keys "
                    "(default is no legend).")
parser.add_argument('--style', type=str, default='-',
                    help="point style, passed to plot function (default='-')")
# The following options are currently disabled.
# parser.add_argument('--exp10-x', action='store_const', const=True, default=False,
#                     help="raise x-axis to power of 10")
# parser.add_argument('--exp10-y', action='store_const', const=True, default=False,
#                     help="raise y-axis to power of 10")
# parser.add_argument('--log10-x', action='store_const', const=True, default=False,
#                     help="take log10 of x-axis")
# parser.add_argument('--log10-y', action='store_const', const=True, default=False,
#                     help="take log10 of y-axis")
parser.add_argument('--scale-x', type=float, default=1.0,
                    help="multiply variables on x-axis by this much (default=1)")
parser.add_argument('--scale-y', type=float, default=1.0,
                    help="multiply variables on y-axis by this much (default=1)")
parser.add_argument('--plotter', type=str, default='plot',
                    choices=['plot', 'semilogx', 'semilogy', 'loglog'],
                    help="use 'matplotlib.pyplot.plotter' to plot (default='plot')")
parser.add_argument('--title', type=str, nargs='+', default=[''],
                    help="Adds the given title to the plot.  Accepts spaces. "
                    "i.e. 'my plot' is OK.  Default is no title.")
parser.add_argument('--style-file', type=str, default=None,
                    help="Specifies a matplotlib style file to load.")
# The example uses real rcParam names: the key is `figure.dpi`, and an
# unknown key such as `fig.dpi` is rejected by matplotlib.
parser.add_argument('--rcParams', type=str, nargs='+', default=[],
                    help="Any parameters in `matplotlib.pyplot.rcParams`, "
                    "provided in the form `key` `value`. "
                    "e.g. --rcParams text.usetex True figure.dpi 300")
parser.add_argument('-v', '--verbose', action='store_true',
                    help="Print diagnostic information as plot is made.")
# Parse the command line now.  `args` is a module-level global that is
# read by `vprint` and by the rest of the script below.
args = parser.parse_args()

def vprint(*print_args, **print_kwargs):
    """Forward all arguments to `print`, but only when the module-level
    `args.verbose` flag is set; otherwise do nothing."""
    if not args.verbose:
        return
    print(*print_args, **print_kwargs)
# The third-party imports happen here, after the command line has been
# parsed, rather than at the top of the file.
# NOTE(review): presumably so that `--help` and argument errors return
# without waiting for matplotlib to import -- confirm.
vprint('Importing libraries... ', end='', flush=True)
import numpy as np
from matplotlib import pyplot as pl
vprint('Done.')

# Load the style file given with --style-file, if any, before plotting.
if not args.style_file:
    vprint('No style file requested.')
else:
    vprint('Applying style file %s... ' % args.style_file,
           end='', flush=True)
    pl.style.use(args.style_file)
    vprint('Done.')

# Map the --plotter choice onto the pyplot function of the same name.
vprint("Selecting plotter '%s'..." % args.plotter)
plotter = {
    'plot': pl.plot,
    'semilogx': pl.semilogx,
    'semilogy': pl.semilogy,
    'loglog': pl.loglog,
}.get(args.plotter)
if plotter is None:
    raise ValueError("invalid choice for --plotter "
                     "(but this should've been caught by argparse)")
    

# Apply `--rcParams key value [key value ...]` to matplotlib's global
# configuration.  The flat argument list is consumed two items at a
# time: first the rcParam name, then its value.
vprint('Parsing `--rcParams`...')
all_keys = set(pl.rcParams.keys())
rc_iter = iter(args.rcParams)
for k in rc_iter:
    # Every key must be followed by a value.  Running out of arguments,
    # or finding another rcParam name where the value should be, means
    # the value for `k` was left out.
    s = next(rc_iter, None)
    if s is None or s in all_keys:
        raise ValueError("I didn't get an argument for "
                         "`rcParam` %s. " % k)

    # An unknown `k` is rejected by matplotlib on this lookup.
    current = pl.rcParams[k]

    if isinstance(current, bool):
        # bool('False') is True, so booleans are spelled out explicitly.
        if s.lower() in ('true', 't'):
            s = True
        elif s.lower() in ('false', 'f'):
            s = False
        else:
            raise ValueError("rcParam %s expects a boolean. "
                             "Please use `T`, `True`, `F` "
                             "or `False` (case insensitive)" % k)
    elif isinstance(current, (int, float, str)):
        # Scalar settings: convert to the type of the current value.
        s = type(current)(s)
    # Anything else (e.g. a setting whose current value is None or a
    # list) cannot be built by calling its type on a string, so the raw
    # string is handed over for matplotlib to validate on assignment.

    pl.rcParams[k] = s

# Draw one curve for every (file, x column, y column) combination; each
# curve is labelled with the name of the file it came from.
for filename in args.filenames:
    vprint('Loading file %s... ' % filename, end='', flush=True)
    header, data = load_data(filename)
    vprint('Done.')

    for kx in args.x:
        x_values = get_column(data, kx)

        for y_key in args.y:
            y_values = get_column(data, y_key)

            scaled_x = x_values*args.scale_x
            scaled_y = y_values*args.scale_y
            plotter(scaled_x, scaled_y, args.style, label=filename)

# Axis labels: an explicit --xlabel/--ylabel (joined with spaces) wins,
# otherwise the first -x/-y key is used.
pl.xlabel(' '.join(args.xlabel) if args.xlabel else args.x[0])
pl.ylabel(' '.join(args.ylabel) if args.ylabel else args.y[0])

# Legend: nothing unless --legend was given; 'auto' uses the labels set
# when the curves were plotted, anything else is used as the entries.
if args.legend and args.legend[0] == 'auto':
    pl.legend()
elif args.legend:
    pl.legend(args.legend)

pl.title(' '.join(args.title))

vprint('Plot ready.')
pl.show()
vprint('Script finished.  Exiting normally.')
