#!/usr/bin/env python
"""
Convert a BPCH file (or files) to a CF-compliant, NetCDF dataset.

This script is a simple utility for opening (and optionally concatenating) BPCH
files and then immediately writing them out to disk in NetCDF format. It's a
thin wrapper designed to avoid having to drop into an interactive Python session
to accomplish this task.

"""

import os, sys

from xbpch import open_bpchdataset, open_mfbpchdataset
from dask.diagnostics import ProgressBar

from argparse import ArgumentParser, RawDescriptionHelpFormatter
# Command-line interface: one or more input BPCH paths, followed by the
# output NetCDF path, plus optional locations for the two metadata files.
parser = ArgumentParser(
    formatter_class=RawDescriptionHelpFormatter,
    description=__doc__,
)

# Positional arguments
parser.add_argument(
    "bpch_files",
    nargs="+",
    type=str,
    help="Paths to BPCH file(s) to load (, concatenate) and write back to disk",
)
parser.add_argument(
    "output_nc",
    type=str,
    help="Name of output file to write",
)

# Optional metadata file locations; both default to the current directory.
parser.add_argument(
    "-t", "--tracerinfo",
    type=str,
    default="tracerinfo.dat",
    metavar="tracerinfo.dat",
    help="Path to tracerinfo.dat, if not in current directory",
)
parser.add_argument(
    "-d", "--diaginfo",
    type=str,
    default="diaginfo.dat",
    metavar="diaginfo.dat",
    help="Path to diaginfo.dat, if not in current directory",
)

def _maybe_del_attr(da, attr):
    """ Possibly delete an attribute on a DataArray if it's present """
    if attr in da.attrs:
        del da.attrs[attr]

    return da


def _maybe_decode_attr(da, attr):
    # TODO: Fix this so that bools get written as attributes just fine
    """ Possibly coerce an attribute on a DataArray to an easier type
    to write to disk. """
    # bool -> int
    if (attr in da.attrs) and (type(da.attrs[attr] == bool)):
        da.attrs[attr] = int(da.attrs[attr])

    return da


if __name__ == "__main__":

    args = parser.parse_args()
    out_path = args.output_nc
    in_paths = args.bpch_files

    # Never clobber an existing output file; bail out right away instead.
    if os.path.exists(out_path):
        print("ERROR: Can't write to output file that already exists.")
        sys.exit(1)

    # Every requested input has to be present on disk before we start.
    missing = [path for path in in_paths if not os.path.exists(path)]
    if missing:
        print("ERROR: Couldn't find the following input files:")
        for path in missing:
            print("   " + path)
        sys.exit(1)

    # Keyword arguments shared by the single-file and multi-file readers.
    reader_kws = dict(
        tracerinfo_file=args.tracerinfo,
        diaginfo_file=args.diaginfo,
        memmap=True,
        dask=True,
    )

    print("\nReading in file(s)...")
    if len(in_paths) == 1:
        ds = open_bpchdataset(in_paths[0], **reader_kws)
    else:
        ds = open_mfbpchdataset(in_paths, **reader_kws)

    # HACK: work around how attribute encodings end up on the DataArrays.
    # Per the original author's note, they are set at a very low level when
    # the data is read, and specifying the encoding manually doesn't work;
    # removing them from the attrs dict reportedly still leaves them in the
    # final output file, where they get written just fine.
    print("\nDecoding variables...")
    for name in ds.data_vars:
        var = ds[name]
        for key in ('scale_factor', 'units'):
            _maybe_del_attr(var, key)
        for key in ('hydrocarbon', 'chemical'):
            _maybe_decode_attr(var, key)
    # The time coordinate gets the same treatment for its units.
    _maybe_del_attr(ds.time, 'units')

    print("\nWriting to " + out_path + " ...")
    with ProgressBar():
        ds.to_netcdf(out_path)

    print("\n done!")

