#!/usr/bin/env python

#
# This file is part of TransportMaps.
#
# TransportMaps is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# TransportMaps is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with TransportMaps.  If not, see <http://www.gnu.org/licenses/>.
#
# Transport Maps Library
# Copyright (C) 2015-2017 Massachusetts Institute of Technology
# Uncertainty Quantification group
# Department of Aeronautics and Astronautics
#
# Author: Transport Map Team
# Website: transportmaps.mit.edu
# Support: transportmaps.mit.edu/qa/
#

from __future__ import print_function

import sys
import getopt
import os
import os.path
import time
import datetime
import logging
import dill as pickle
import numpy as np
import scipy.stats as stats
import TransportMaps as TM
import TransportMaps.CLI as TMCLI
import TransportMaps.Maps as MAPS
import TransportMaps.Distributions as DIST
import TransportMaps.XML as TMXML
import TransportMaps.Algorithms.Adaptivity as ALGADPT

# Make modules in the current working directory importable.
# NOTE(review): presumably so that dill can resolve classes defined in user
# modules when unpickling the --dist / --base-dist files; confirm.
sys.path.append(os.getcwd())

# Data storage object
# Instance of an anonymous, dynamically created class, used as a plain
# attribute bag: run settings (QTYPE, QNUM, TOL, ...) and results
# (distributions, transport map) are attached to it, and the whole object is
# pickled with dill into OUTPUT.  Changing how it is built would change what
# ends up in the output file.
stg = type('', (), {})()

def usage():
    """Print the short command-line synopsis of ``tmap-tm`` to stdout.

    Option names listed here must match the ones accepted by the getopt
    parser (e.g. ``--reg``, not ``--with-reg``), and option values are shown
    as upper-case placeholders (``--adapt=ADAPT``), as in the detailed
    description.
    """
    usage_str = """
Usage: tmap-tm [-h -f -I] 
  --dist=DIST --output=OUTPUT [--base-dist=BASE_DIST]
  (--mtype=MTYPE --span=SPAN --btype=BTYPE --order=ORDER)
    / (--map-descr=MAP_DESCR)
  --qtype=QTYPE --qnum=QNUM
  [--tol=TOL --reg=REG --ders=DERS --fungrad --adapt=ADAPT]
  [--laplace-pull]
  [--log=LOG --nprocs=NPROCS --batch=BATCH]
"""
    print(usage_str)

def description():
    """Print the detailed help text (DESCRIPTION and OPTIONS sections).

    Options that take one of an enumerated set of values get a header line
    followed by the list of accepted values. That list comes from
    ``TMCLI.print_avail_options`` applied to the matching ``TMCLI.AVAIL_*``
    table, and its return value is concatenated into the text. All other
    option descriptions are literal text, spliced together in display order.
    """
    # Enumerated options: header line + available values, indented so the
    # value lists line up under the description column.
    docs_monotone_str = \
        '  --mtype=MTYPE           monotone format for the transport\n' + \
        TMCLI.print_avail_options(TMCLI.AVAIL_MONOTONE,'                          ')
    docs_span_str = \
        '  --span=SPAN             span type for all the components of the map\n' + \
        TMCLI.print_avail_options(TMCLI.AVAIL_SPAN,'                          ')
    docs_btype_str = \
        '  --btype=BTYPE           basis types for all the components of the map\n' + \
        TMCLI.print_avail_options(TMCLI.AVAIL_BTYPE,'                          ')
    docs_qtype_str = \
        '  --qtype=QTYPE           quadrature type for the discretization of ' + \
        'the KL-divergence\n' + \
        TMCLI.print_avail_options(TMCLI.AVAIL_QTYPE,'                          ')
    docs_ders_str = \
        '  --ders=DERS             derivatives to be used in the optimization\n' + \
        TMCLI.print_avail_options(TMCLI.AVAIL_DERS,'                          ')
    docs_adaptivity_str = \
        '  --adapt=ADAPT           adaptivity algorithm for map construction\n' + \
        TMCLI.print_avail_options(TMCLI.AVAIL_ADAPTIVITY,'                          ')
    docs_log_str = \
        '  --log=LOG               log level (default=30). Uses package logging.\n' + \
        TMCLI.print_avail_options(TMCLI.AVAIL_LOGGING,'                          ')

    # Full help text, assembled in display order.
    # NOTE(review): the --qnum entry only says "quadrature level", but the
    # parser accepts a comma-separated list of ints and keeps just the first
    # one when QTYPE < 3; --batch likewise expects comma-separated ints.
    docs_str = """DESCRIPTION
Given a file (--dist) storing the target distribution, produce the transport map that
pushes forward the base distribution (default: standard normal) to the target distribution.
All files involved are stored and loaded using the python package dill.

OPTIONS - input/output:
  --dist=DIST             path to the file containing the target distribution 
  --output=OUTPUT         path to the output file containing the transport map,  
                          the base distribution, the target distribution and all 
                          the additional parameters used for the construction 
  --base-dist=BASE_DIST   path to the file containing the base distribution
                          (default: a standard normal of suitable dimension)
OPTIONS - map description (using default maps):
""" + docs_monotone_str + docs_span_str + docs_btype_str + \
"""  --order=ORDER           order of the transport map
OPTIONS - map description (manual):
  --map-descr=MAP_DESCR   XML file containing the skeleton of the transport map
OPTIONS - solver:
""" + docs_qtype_str + \
"""  --qnum=QNUM             quadrature level
  --tol=TOL               optimization tolerance (default: 1e-4)
  --reg=REG               a float L2 regularization parameter
                          (default: no regularization)
""" + docs_ders_str + \
"""  --fungrad               whether the distributions provide a method to compute
                          the log pdf and its gradient at the same time
""" + docs_adaptivity_str + \
"""  --laplace-pull          whether to precondition pulling back the target through
                          its Laplace approximation
""" + docs_log_str + \
"""  --nprocs=NPROCS         number of processors to be used (default=1)
  --batch=BATCH           list of batch sizes for function evaluation, gradient
                          evaluation and Hessian evaluation
OPTIONS - other:
  -f                      force overwrite of OUTPUT file
  -I                      enter interactive mode after finishing
  -h                      print this help
"""
    print(docs_str)

def full_usage():
    """Print the short usage synopsis; thin alias kept for the CLI helpers."""
    return usage()

def full_doc():
    """Print the usage synopsis followed by the detailed option description."""
    for section in (full_usage, description):
        section()

##################### INPUT PARSING #####################
# This section runs at import time: defaults are assigned first, then
# overridden by the command-line options.  Settings stored on ``stg`` are
# pickled together with the results into OUTPUT.
argv = sys.argv[1:]
INTERACTIVE = False
# I/O
DIST_FNAME = None
OUT_FNAME = None
BASE_DIST_FNAME = None
FORCE = False
# Map type
MONOTONE = None
SPAN = None
BTYPE = None
ORDER = None
MAP_DESCR = None
ADAPT = 'none'
# Quadrature type
stg.QTYPE = None
stg.QNUM = None
# Solver options
stg.TOL = 1e-4
stg.REG = None
stg.DERS = 2
stg.FUNGRAD = False
# Pre-pull Laplace
stg.LAPLACE_PULL = False
# Logging
LOGGING_LEVEL = 30 # Warnings
# Parallelization
NPROCS = 1
BATCH_SIZE = [None, None, None]
try:
    opts, args = getopt.getopt(argv,"hfI",
                               [
                                   # I/O
                                   "dist=", "output=", "base-dist=",
                                   # Map type
                                   "mtype=", "span=", "btype=", "order=",
                                   "map-descr=",
                                   # Quadrature type
                                   "qtype=", "qnum=",
                                   # Solver options
                                   "tol=", "reg=", "ders=", "fungrad",
                                   # Adaptivity options
                                   "adapt=",
                                   # Whether to pre-pull through Laplace
                                   "laplace-pull",
                                   # Logging
                                   "log=",
                                   # Parallelization and batching option
                                   "nprocs=", "batch="
                               ])
except getopt.GetoptError as err:
    # Unknown option or missing option value: report it like every other
    # argument error in this script (usage + message, exit status 3)
    # instead of surfacing a Python traceback.
    full_usage()
    print("ERROR: %s" % err)
    sys.exit(3)
for opt, arg in opts:
    if opt == '-h':
        full_doc()
        sys.exit()

    # Force overwrite
    elif opt == '-f':
        FORCE = True

    # Interactive
    elif opt == '-I':
        INTERACTIVE = True

    # I/O
    elif opt in ['--dist']:
        DIST_FNAME = arg
    elif opt in ['--output']:
        OUT_FNAME = arg
    elif opt in ['--base-dist']:
        BASE_DIST_FNAME = arg

    # Map type
    elif opt in ['--mtype']:
        MONOTONE = arg
    elif opt in ['--span']:
        SPAN = arg
    elif opt in ['--btype']:
        BTYPE = arg
    elif opt in ['--order']:
        ORDER = int(arg)
    elif opt in ['--map-descr']:
        MAP_DESCR = arg

    # Quadrature type
    elif opt in ['--qtype']:
        stg.QTYPE = int(arg)
    elif opt in ['--qnum']:
        # Comma-separated ints; reduced to the first entry later if QTYPE < 3.
        stg.QNUM = [int(q) for q in arg.split(',')]

    # Solver options
    elif opt in ['--tol']:
        stg.TOL = float(arg)
    elif opt in ['--reg']:
        stg.REG = {'type': 'L2',
                   'alpha': float(arg)}
    elif opt in ['--ders']:
        stg.DERS = int(arg)
    elif opt == '--fungrad':
        stg.FUNGRAD = True

    # Adaptivity options
    elif opt == '--adapt':
        ADAPT = arg

    # Pre-pull Laplace
    elif opt in ['--laplace-pull']:
        stg.LAPLACE_PULL = True

    # Logging
    elif opt in ['--log']:
        LOGGING_LEVEL = int(arg)

    # Parallelization and batching
    elif opt in ['--nprocs']:
        NPROCS = int(arg)
    elif opt in ['--batch']:
        # Comma-separated ints (batch sizes for function/gradient/Hessian
        # evaluation, per the help text).
        BATCH_SIZE = [int(b) for b in arg.split(',')]

    else:
        # Defensive only: getopt already rejects options not listed above.
        raise ValueError("Option %s not recognized" % opt)

def tstamp_print(msg, *args, **kwargs):
    """Print ``msg`` prefixed with the current local time.

    The prefix has the form ``YYYY-MM-DD HH:MM:SS`` and is joined to ``msg``
    with a single space. Any extra positional or keyword arguments are
    forwarded unchanged to :func:`print`.
    """
    stamp = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    line = ' '.join([stamp, msg])
    print(line, *args, **kwargs)
        
# Check for required arguments
if None in [DIST_FNAME, OUT_FNAME]:
    usage()
    tstamp_print("ERROR: Options --dist and --output must be specified")
    sys.exit(3)
if None in [stg.QTYPE, stg.QNUM]:
    usage()
    tstamp_print("ERROR: Options --qtype and --qnum must be specified")
    sys.exit(3)
# For QTYPE < 3 only the first --qnum value is kept, as a scalar; otherwise
# the whole list is passed on to the solver.
if stg.QTYPE < 3:
    stg.QNUM = stg.QNUM[0]
# The map is described EITHER by all four default-map options OR by an XML
# skeleton (--map-descr).  A partial default description, or a mix of the
# two, is rejected.
map_descr_list = [MONOTONE, SPAN, BTYPE, ORDER]
if (MAP_DESCR is None and None in map_descr_list) or \
   (MAP_DESCR is not None and any(s is not None for s in map_descr_list)):
    usage()
    tstamp_print("ERROR: Either options --mtype, --span, --btype, " + \
          "--order are specified or option --map-descr is specified")
    sys.exit(3)
if ADAPT not in TMCLI.AVAIL_ADAPTIVITY:
    usage()
    tstamp_print("ERROR: adaptivity algorithm not recognized")
    sys.exit(3)
# Left-pad BATCH_SIZE with None up to one entry per derivative level
# 0..DERS (a negative count pads nothing), so user-given sizes occupy the
# trailing slots.
# NOTE(review): a single --batch value therefore applies only to the
# highest derivative level; confirm this is intended.
BATCH_SIZE = [None] * (stg.DERS + 1 - len(BATCH_SIZE)) + BATCH_SIZE

# Ask before overwriting an existing OUTPUT unless -f was given.
if not FORCE and os.path.exists(OUT_FNAME):
    prompt = "The file %s already exists. " % OUT_FNAME + \
             "Do you want to overwrite? [y/N/q] "
    sel = ''
    while sel not in ['y', 'Y', 'n', 'N', 'q']:
        try:
            if sys.version_info[0] >= 3:
                sel = input(prompt)
            else:
                sel = raw_input(prompt)
        except EOFError:
            # No answer can ever arrive (stdin closed, e.g. a batch job run
            # without -f): refuse to overwrite instead of crashing.
            print()
            sel = 'q'
        if sel == '':
            # A bare <Enter> selects the advertised default (capital N)
            # instead of re-prompting forever.
            sel = 'N'
    if sel in ('n', 'N', 'q'):
        tstamp_print("Terminating.")
        sys.exit(0)

try:
    ##################### DATA LOADING #####################
    TM.setLogLevel(LOGGING_LEVEL)

    # Fail fast, before any (possibly expensive) unpickling, on adaptivity
    # settings this script cannot handle.  Only 'none' and 'sequential' are
    # wired up below.  Any other value that TMCLI.AVAIL_ADAPTIVITY might
    # accept would otherwise leave tm_approx / builder_constructor unbound.
    if ADAPT not in ('none', 'sequential'):
        usage()
        tstamp_print("ERROR: option --adapt=%s is not supported by this script" % ADAPT)
        sys.exit(3)
    if MAP_DESCR is not None and ADAPT == 'sequential':
        usage()
        tstamp_print("ERROR: option --adapt=sequential is not available with --map-descr (yet..)")
        sys.exit(3)

    # Load target distribution
    # NOTE(review): dill/pickle loading can execute arbitrary code; only load
    # --dist / --base-dist files that come from a trusted source.
    with open(DIST_FNAME,'rb') as istr:
        stg.target_distribution = pickle.load(istr)
    dim = stg.target_distribution.dim

    # Load base distribution (default: standard normal of the target's dim)
    if BASE_DIST_FNAME is None:
        stg.base_distribution = DIST.StandardNormalDistribution(dim)
    else:
        with open(BASE_DIST_FNAME,'rb') as istr:
            stg.base_distribution = pickle.load(istr)

    # Instantiate Transport Map
    if MAP_DESCR is not None:
        tm_approx = TMXML.load_xml(MAP_DESCR)
    else:
        if MONOTONE == 'linspan':
            map_constructor = TM.Default_IsotropicMonotonicLinearSpanTriangularTransportMap
        elif MONOTONE == 'intexp':
            map_constructor = TM.Default_IsotropicIntegratedExponentialTriangularTransportMap
        elif MONOTONE == 'intsq':
            map_constructor = TM.Default_IsotropicIntegratedSquaredTriangularTransportMap
        else:
            raise ValueError("Monotone type not recognized (linspan|intexp|intsq)")
        if ADAPT == 'none':
            tm_approx = map_constructor(dim, ORDER, span=SPAN, btype=BTYPE)
            logging.info("Number coefficients: %d", tm_approx.n_coeffs)
        elif ADAPT == 'sequential':
            # One map per order 1..ORDER, handed to the sequential builder.
            tm_approx = [map_constructor(dim, o, span=SPAN, btype=BTYPE)
                         for o in range(1, ORDER + 1)]
            n_coeffs = sum(tm.n_coeffs for tm in tm_approx)
            logging.info("Number coefficients: %d", n_coeffs)

    # Laplace pullback: pull the target back through the linear map built
    # from its Laplace (Gaussian) approximation, and learn the transport map
    # on that preconditioned target instead.
    if stg.LAPLACE_PULL:
        laplace_approx = TM.laplace_approximation( stg.target_distribution )
        lapmap = MAPS.LinearTransportMap.build_from_Gaussian( laplace_approx )
        tar = DIST.PullBackTransportMapDistribution( lapmap, stg.target_distribution )
    else:
        tar = stg.target_distribution

    # Set up adaptivity algorithm (ADAPT was validated above)
    if ADAPT == 'none':
        builder_constructor = ALGADPT.KullbackLeiblerBuilder
    elif ADAPT == 'sequential':
        builder_constructor = ALGADPT.SequentialKullbackLeiblerBuilder
    builder = builder_constructor(stg.base_distribution, tar)

    # Minimize KL-divergence
    mpi_pool = None
    if NPROCS > 1:
        mpi_pool = TM.get_mpi_pool()
        mpi_pool.start(NPROCS)

    try:
        # ``log`` is not used further, but it stays reachable in -I mode.
        (tm_approx, log) = builder.solve(
            tm_approx, qtype=stg.QTYPE, qparams=stg.QNUM, regularization=stg.REG,
            tol=stg.TOL, ders=stg.DERS, fungrad=stg.FUNGRAD,
            batch_size=BATCH_SIZE, mpi_pool=mpi_pool)
    finally:
        # Always shut the worker pool down, even if the solve fails.
        if mpi_pool is not None:
            mpi_pool.stop()

    # With the Laplace pre-pull, the final map is the composite of the
    # Laplace linear map and the learned map.
    if stg.LAPLACE_PULL:
        stg.tmap = MAPS.CompositeTransportMap(lapmap, tm_approx)
    else:
        stg.tmap = tm_approx

    # Approximations induced by the map: target pulled back (approximate
    # base) and base pushed forward (approximate target).
    stg.approx_base_distribution = DIST.PullBackTransportMapDistribution(stg.tmap, stg.target_distribution)
    stg.approx_target_distribution = DIST.PushForwardTransportMapDistribution(stg.tmap, stg.base_distribution)

    # STORE: settings and results attached to ``stg`` are pickled together.
    with open(OUT_FNAME, 'wb') as out_stream:
        pickle.dump(stg, out_stream)

finally:
    # -I: drop into an IPython shell even if construction failed or exited.
    if INTERACTIVE:
        from IPython import embed
        embed()
