#!/usr/bin/env python3

""" Command line utility to convert SQLite to NetCDF """

import argparse
import os
import sqlite3
import sys
from datetime import datetime
import netCDF4 as nc
import numpy as np


def arguments():
    """Build the command-line interface and return the parsed arguments.

    The returned namespace carries ``infile`` (positional), ``outfile``
    (``-o``), ``force`` (``-F``) and ``verbose`` (``-v``).
    """

    blurb = """
    Program for converting .db file format to NetCDF format.

    For help, contact John.Krasting@noaa.gov
    """

    cli = argparse.ArgumentParser(
        formatter_class=argparse.RawTextHelpFormatter,
        description=blurb,
    )

    # -- Input file (positional)
    cli.add_argument(
        "infile", help="Input file. Format must be sqlite (*.db)", type=str
    )

    # -- Optional switches, declared as (flag names, keyword settings) pairs
    #    and registered in this order so the help output keeps its layout
    option_specs = (
        (
            ("-o", "--outfile"),
            {
                "type": str,
                "default": "out.nc",
                "help": "Output file. Default name is out.nc",
            },
        ),
        (
            ("-F", "--force"),
            {
                "action": "store_true",
                "default": False,
                "help": "Clobber existing output file if it exists.",
            },
        ),
        (
            ("-v", "--verbose"),
            {
                "action": "store_true",
                "default": False,
                "help": "Verbose output. Default is quiet output.",
            },
        ),
    )
    for flag_names, settings in option_specs:
        cli.add_argument(*flag_names, **settings)

    return cli.parse_args()


def check_file(filepath, clobber=False):
    """
    Make sure the output path is free to be written.

    Does nothing if ``filepath`` does not exist. If it exists and
    ``clobber`` is truthy, the existing file is removed; otherwise a
    ValueError is raised.

    Parameters
    ----------
    filepath : str
        Path of the output file about to be created.
    clobber : bool, optional
        When truthy, delete an existing file at ``filepath``.

    Raises
    ------
    ValueError
        If ``filepath`` exists and ``clobber`` is falsy.
    """
    if not os.path.exists(filepath):
        return

    # Plain truthiness (not ``is True``) so that values such as 1 also work
    if not clobber:
        raise ValueError(
            "FATAL: output file exists. Try using the "
            "'-F' option to force overwrite."
        )

    os.remove(filepath)


def tables_and_years(dbfile):
    """
    Function to get list of variables and years in a .db file

    Every table listed in ``sqlite_master`` is returned. The tables named
    ``long_name``, ``units`` and ``cell_measure`` are skipped when collecting
    years; every other table must provide a ``year`` column.

    Parameters
    ----------
    dbfile : str
        Path to the SQLite database file.

    Returns
    -------
    tuple
        ``(tables, years)`` where ``tables`` is the list of all table names
        and ``years`` is the sorted list of distinct years (as ints) found
        across the non-metadata tables.

    Raises
    ------
    sqlite3.Error
        If a non-metadata table cannot be queried for its ``year`` column.
        A diagnostic is written to stderr before the error is re-raised.
    """
    metadata_tables = ("long_name", "units", "cell_measure")

    # A single connection serves every query; ``finally`` guarantees it is
    # released even when one of the tables cannot be read.
    conn = sqlite3.connect(dbfile)
    try:
        cur = conn.cursor()
        cur.execute("SELECT name FROM sqlite_master WHERE type='table'")
        tables = [str(record[0]) for record in cur.fetchall()]

        years = set()
        for table in tables:
            if table in metadata_tables:
                continue
            # Quote the identifier so unusual table names remain valid SQL
            quoted = '"' + table.replace('"', '""') + '"'
            try:
                cur.execute("SELECT year FROM " + quoted)
            except sqlite3.Error:
                # sqlite3 raises its own exception hierarchy (not
                # RuntimeError), so this is the handler that actually fires
                sys.stderr.write("Unable to process " + table + "\n")
                sys.stderr.write(
                    "Variable " + table + " does not contain year axis\n"
                )
                raise
            years.update(int(record[0]) for record in cur.fetchall())
    finally:
        conn.close()

    # Sorted so that callers get a monotonically increasing list of years
    return tables, sorted(years)


def _metadata_value(cur, meta_table, var):
    """Return the ``value`` stored for ``var`` in ``meta_table`` ("" if none)."""
    # The metadata tables are optional: check for existence first so that a
    # database without them yields an empty attribute instead of an error.
    cur.execute(
        "SELECT 1 FROM sqlite_master WHERE type='table' AND name=?", (meta_table,)
    )
    if cur.fetchone() is None:
        return ""
    # ``meta_table`` is only ever one of the fixed names used by write_nc
    cur.execute(f"SELECT value FROM {meta_table} WHERE var=?", (var,))
    row = cur.fetchone()
    if row is None or row[0] is None:
        return ""
    return row[0]


def _read_series(cur, table, column, years):
    """Return a masked array with ``column`` of ``table`` for each of ``years``."""
    # Quote the identifier so unusual table names remain valid SQL
    quoted = '"' + table.replace('"', '""') + '"'
    sql = f"SELECT {column} FROM {quoted} WHERE year=?"

    # Start fully masked; entries are unmasked only for years that have data
    series = np.ma.ones(len(years)) + 1.0e20
    series.mask = True
    for idx, year in enumerate(years):
        cur.execute(sql, (int(year),))
        row = cur.fetchone()
        if row is not None:
            series[idx] = float(row[0])
            series.mask[idx] = False
    return series


def write_nc(
    dbfile,
    outfile,
    tables,
    years,
    clobber=False,
    ncformat="NETCDF3_CLASSIC",
    verbose=False,
):
    """Writes output to a netCDF file

    One variable along a ``time`` dimension is written for every table in
    ``tables`` other than ``long_name``, ``units`` and ``cell_measure``.
    Years with no row in a table are left masked in that variable.

    Parameters
    ----------
    dbfile : str
        Path to the SQLite database to read; also stored in the
        ``source_file`` global attribute.
    outfile : str
        Path of the netCDF file to create.
    tables : list of str
        Table names to convert.
    years : list of int
        Years defining the time axis, in the order they are written.
    clobber : bool, optional
        Passed to ``check_file``; when truthy an existing ``outfile`` is
        removed first.
    ncformat : str, optional
        netCDF format passed to ``netCDF4.Dataset``.
    verbose : bool, optional
        When truthy, print one line for each table processed.

    Raises
    ------
    ValueError
        From ``check_file`` if ``outfile`` exists and ``clobber`` is falsy.
    """
    metadata_tables = ("long_name", "units", "cell_measure")
    # Only the "value" column is extracted. With more than one column
    # (e.g. "avg", "sum") each one gets its own suffixed output variable.
    extract_list = ["value"]

    check_file(outfile, clobber=clobber)
    timestamp_str = datetime.now().strftime("%d-%b-%Y (%H:%M:%S.%f)")

    ncfile = nc.Dataset(outfile, "w", format=ncformat)
    conn = None
    try:
        ncfile.setncattr("source_file", dbfile)
        ncfile.setncattr("created", timestamp_str)

        # Size 0 is kept from the original call; netCDF4 documents this as
        # an unlimited dimension
        ncfile.createDimension("time", 0)
        time_var = ncfile.createVariable("time", "f4", ("time",))
        time_var.calendar = "noleap"
        time_var.units = "days since 0001-01-01 00:00:00.0"
        # 365-day years plus a 196-day offset (presumably a mid-year stamp)
        time_var[:] = ((np.array(years) - 1) * 365.0) + 196.0

        # One connection for the whole conversion
        conn = sqlite3.connect(dbfile)
        cur = conn.cursor()

        for table in tables:
            if table in metadata_tables:
                continue

            long_name = _metadata_value(cur, "long_name", table)
            units = _metadata_value(cur, "units", table)
            if verbose:
                print(f"Processing {table}  :  {long_name}")

            for column in extract_list:
                data_array = _read_series(cur, table, column, years)

                if len(extract_list) > 1:
                    outname = table + "_" + column.replace("sum", "int")
                else:
                    outname = table

                var = ncfile.createVariable(outname, "f4", ("time",))
                if long_name != "":
                    var.long_name = long_name
                if units != "":
                    var.units = units
                var[:] = data_array
    finally:
        # Release both handles even if the conversion fails part-way
        if conn is not None:
            conn.close()
        ncfile.close()


if __name__ == "__main__":

    # read command line arguments
    args = arguments()
    # get the full path of the db file
    infile = os.path.realpath(args.infile)
    # get list of variables and years in a file
    _tables, _years = tables_and_years(infile)
    write_nc(
        infile, args.outfile, _tables, _years, clobber=args.force, verbose=args.verbose
    )

    # Exit explicitly with status 0. This lives inside the guard so that
    # importing this module does not terminate the interpreter.
    sys.exit()
