#!/usr/bin/env python3

# This file is part of tf-mdp.

# tf-mdp is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# tf-mdp is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with tf-mdp. If not, see <http://www.gnu.org/licenses/>.


import argparse
import time


def parse_args(argv=None):
    """Parse the tf-mdp command-line arguments.

    Args:
        argv: optional list of argument strings to parse. When None (the
            default), argparse reads ``sys.argv[1:]`` as before.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    description = 'Probabilistic planning in continuous state-action MDPs using TensorFlow.'
    parser = argparse.ArgumentParser(description=description)
    parser.add_argument('rddl', type=str, help='RDDL filepath')
    parser.add_argument(
        '-c', '--channels',
        type=int,
        required=True,
        help='number of channels in policy network'
    )
    # NOTE(review): layer sizes are kept as strings (no type=int) because
    # print_parameters formats them with ','.join; confirm what the planner
    # expects before converting them.
    parser.add_argument(
        '-l', '--layers',
        nargs='+',
        required=True,
        help='number of units in each hidden layer in policy network'
    )
    parser.add_argument(
        '-ln', '--layer-norm',
        action='store_true',
        help='layer normalization flag'
    )
    parser.add_argument(
        '-ld', '--logdir',
        type=str, default='/tmp/tfmdp',
        help='log directory for data summaries (default=/tmp/tfmdp)'
    )
    parser.add_argument(
        '-b', '--batch-size',
        type=int, default=1024,
        help='number of trajectories in a batch (default=1024)'
    )
    parser.add_argument(
        '-hr', '--horizon',
        type=int, default=40,
        help='number of timesteps (default=40)'
    )
    parser.add_argument(
        '-e', '--epochs',
        type=int, default=500,
        help='number of training epochs (default=500)'
    )
    parser.add_argument(
        '-lr', '--learning-rate',
        type=float, default=0.001,
        help='optimizer learning rate (default=0.001)'
    )
    parser.add_argument(
        '-v', '--verbose',
        action='store_true',
        help='verbosity mode'
    )
    return parser.parse_args(argv)


def print_parameters(args):
    """Echo the run configuration to stdout, but only when ``args.verbose`` is set.

    The banner includes the installed tfmdp version, the RDDL path, the policy
    network settings and the training settings, each section separated by a
    blank line.

    Args:
        args: parsed CLI namespace as produced by parse_args.
    """
    if not args.verbose:
        return

    import tfmdp

    # header
    print()
    print('Running tf-mdp v{} ...'.format(tfmdp.__version__))
    print()
    print('>> RDDL: {}'.format(args.rddl))
    print()

    # policy network section
    print('>> Policy Net:')
    print('channels  = {}'.format(args.channels))
    print('layers    = [{}]'.format(','.join(args.layers)))
    print('layernorm = {}'.format(args.layer_norm))
    print()

    # training section
    print('>> Training:')
    print('logdir        = {}'.format(args.logdir))
    print('batch size    = {}'.format(args.batch_size))
    print('horizon       = {}'.format(args.horizon))
    print('epochs        = {}'.format(args.epochs))
    print('learning rate = {}'.format(args.learning_rate))
    print()


def read_file(path):
    """Return the entire text content of the file at *path*.

    The file is decoded as UTF-8 explicitly. Without an ``encoding`` argument,
    ``open()`` falls back to the locale's preferred encoding, so the same RDDL
    file could decode differently (or fail) depending on the host system.

    Args:
        path: filesystem path of the file to read.

    Returns:
        The file contents as a ``str``.

    Raises:
        OSError: if the file cannot be opened.
        UnicodeDecodeError: if the file is not valid UTF-8.
    """
    with open(path, 'r', encoding='utf-8') as f:
        return f.read()


def parse_rddl(path):
    """Parse the RDDL file at *path* with pyrddl.

    A fresh ``RDDLParser`` is created and built, then fed the file's text.

    Args:
        path: filesystem path of the RDDL file.

    Returns:
        Whatever ``RDDLParser.parse`` returns for the file contents
        (presumably the parsed RDDL model object; see pyrddl).
    """
    from pyrddl.parser import RDDLParser

    rddl_parser = RDDLParser()
    rddl_parser.build()
    source_text = read_file(path)
    return rddl_parser.parse(source_text)


def compile(rddl):
    """Construct and return a tfrddlsim ``Compiler`` for *rddl* in batch mode.

    NOTE: this function shadows the builtin ``compile`` inside this module; the
    name is kept so existing callers keep working.

    Args:
        rddl: parsed RDDL model, as returned by parse_rddl.

    Returns:
        A ``tfrddlsim.rddl2tf.compiler.Compiler`` built with ``batch_mode=True``.
    """
    from tfrddlsim.rddl2tf.compiler import Compiler
    return Compiler(rddl, batch_mode=True)


def solve(args, rddl2tf=None):
    """Build the policy-optimization planner, train it and time the training run.

    Args:
        args: parsed CLI namespace (see parse_args). This function reads
            channels, layers, layer_norm, logdir, learning_rate, batch_size,
            horizon and epochs. It also reads rddl when no compiler is
            available.
        rddl2tf: compiled RDDL model, as returned by compile(). This argument
            is optional for backward compatibility. When omitted, the
            module-level ``rddl2tf`` bound by the script entry point is used
            if it exists. Otherwise the model is parsed and compiled from
            ``args.rddl``.

    Returns:
        A tuple ``(policy, logdir, elapsed)``:
        - policy and logdir are the two values returned by ``planner.run``.
        - elapsed is the number of seconds spent in ``planner.run``.
          The preceding ``planner.build`` call is not timed.
    """
    if rddl2tf is None:
        # Previously the compiler was picked up implicitly from a global that
        # only exists when this file runs as a script, so importing and
        # calling solve() raised NameError. Resolve it explicitly instead.
        rddl2tf = globals().get('rddl2tf')
        if rddl2tf is None:
            rddl2tf = compile(parse_rddl(args.rddl))

    from tfmdp.planner import PolicyOptimizationPlanner
    planner = PolicyOptimizationPlanner(
        rddl2tf, args.channels, args.layers, args.layer_norm, args.logdir)
    planner.build(args.learning_rate, args.batch_size, args.horizon)

    # perf_counter is monotonic and high-resolution, so the measured duration
    # is not affected by system clock adjustments (unlike time.time()).
    start = time.perf_counter()
    policy, logdir = planner.run(args.epochs)
    elapsed = time.perf_counter() - start
    return policy, logdir, elapsed


if __name__ == '__main__':

    # Step 1: command-line configuration (echoed only with --verbose).
    args = parse_args()
    print_parameters(args)

    # Step 2: front end -- RDDL text is parsed, then handed to the compiler.
    # ``rddl2tf`` is deliberately bound at module scope because solve() may
    # read it from there.
    rddl = parse_rddl(args.rddl)
    rddl2tf = compile(rddl)

    # Step 3: train, then report the elapsed time and the summaries location,
    # each message preceded by a blank line.
    policy, logdir, uptime = solve(args)
    for template, value in (('Done in {:.6f} sec.', uptime),
                            ('tensorboard --logdir {}', logdir)):
        print()
        print(template.format(value))
