#!/usr/bin/env python3

# This file is part of tf-mdp.

# tf-mdp is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# tf-mdp is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with tf-mdp. If not, see <http://www.gnu.org/licenses/>.


import rddlgym

import argparse
import time


def parse_args():
    """Parse the command-line arguments of the tf-mdp script.

    Arguments are read from ``sys.argv``.

    Returns:
        argparse.Namespace: with attributes ``rddl``, ``layers``,
        ``layer_norm``, ``logdir``, ``batch_size``, ``horizon``, ``epochs``,
        ``learning_rate`` and ``verbose``.
    """
    description = 'Probabilistic planning in continuous state-action MDPs using TensorFlow.'
    parser = argparse.ArgumentParser(description=description)
    parser.add_argument(
        'rddl',
        type=str,
        help='RDDL file or rddlgym domain id'
    )
    # No `type=` here: the layer sizes are kept as the raw strings typed on
    # the command line (print_parameters joins them with ','.join).
    parser.add_argument(
        '-l', '--layers',
        nargs='+',
        default=[],
        help='number of units in each hidden layer in policy network'
    )
    parser.add_argument(
        '-ln', '--layer-norm',
        action='store_true',
        help='layer normalization flag'
    )
    parser.add_argument(
        '-ld', '--logdir',
        type=str, default='/tmp/tfmdp',
        help='log directory for data summaries (default=/tmp/tfmdp)'
    )
    parser.add_argument(
        '-b', '--batch-size',
        type=int, default=256,
        help='number of trajectories in a batch (default=256)'
    )
    parser.add_argument(
        '-hr', '--horizon',
        type=int, default=40,
        help='number of timesteps (default=40)'
    )
    # The help text used to be a copy of --horizon's ("number of timesteps").
    parser.add_argument(
        '-e', '--epochs',
        type=int, default=200,
        help='number of training epochs (default=200)'
    )
    parser.add_argument(
        '-lr', '--learning-rate',
        type=float, default=0.001,
        help='optimizer learning rate (default=0.001)'
    )
    parser.add_argument(
        '-v', '--verbose',
        action='store_true',
        help='verbosity mode'
    )
    return parser.parse_args()


def print_parameters(args):
    """Print a summary of the run configuration when verbose mode is on.

    Does nothing unless ``args.verbose`` is truthy. Otherwise prints the
    tf-mdp version, the RDDL id and log directory, the policy network
    settings and the training hyper-parameters, one item per line.

    Args:
        args: parsed command-line namespace (see ``parse_args``).
    """
    if not args.verbose:
        return

    # Only imported on the verbose path, where the version is needed.
    import tfmdp

    # An empty string prints as a blank line, same as a bare print().
    summary = [
        '',
        'Running tf-mdp v{} ...'.format(tfmdp.__version__),
        '',
        '>> RDDL:   {}'.format(args.rddl),
        '>> logdir: {}'.format(args.logdir),
        '',
        '>> Policy Net:',
        'layers    = [{}]'.format(','.join(args.layers)),
        'layernorm = {}'.format(args.layer_norm),
        '',
        '>> Training:',
        'epochs        = {}'.format(args.epochs),
        'learning rate = {}'.format(args.learning_rate),
        'batch size    = {}'.format(args.batch_size),
        'horizon       = {}'.format(args.horizon),
        '',
    ]
    for line in summary:
        print(line)


def load_model(args):
    """Create the model for ``args.rddl`` and switch it to batch mode.

    The object is obtained from ``rddlgym.make`` in ``rddlgym.SCG`` mode and
    has ``batch_mode_on()`` called on it before being returned.

    Args:
        args: parsed command-line namespace; only ``args.rddl`` is read.

    Returns:
        The object returned by ``rddlgym.make``, with batch mode enabled.
    """
    model = rddlgym.make(args.rddl, mode=rddlgym.SCG)
    model.batch_mode_on()
    return model


def solve(compiler, args):
    """Build a policy-optimization planner and run it.

    Args:
        compiler: the model object produced by ``load_model``.
        args: parsed command-line namespace; reads ``layers``, ``layer_norm``,
            ``logdir``, ``learning_rate``, ``batch_size``, ``horizon`` and
            ``epochs``.

    Returns:
        tuple: ``(rewards, policy, logdir)`` as unpacked from the planner's
        ``run`` call.
    """
    from tfmdp.planner import PolicyOptimizationPlanner

    # NOTE(review): args.layers holds the raw strings parsed from the CLI;
    # confirm the planner converts them to ints where needed.
    optimizer = PolicyOptimizationPlanner(
        compiler,
        args.layers,
        args.layer_norm,
        args.logdir,
    )
    optimizer.build(args.learning_rate, args.batch_size, args.horizon)

    outcome = optimizer.run(args.epochs)
    rewards, policy, logdir = outcome
    return rewards, policy, logdir


def report_performance(rewards, horizon):
    """Print the final total reward and its per-timestep average.

    Args:
        rewards: non-empty sequence; the value at index 1 of its last item is
            reported as the total reward (presumably ``(epoch, reward)``
            pairs -- confirm against the planner's output).
        horizon: number of timesteps used to average the total reward.
    """
    total = rewards[-1][1]
    print('>> Performance:')
    per_step = total / horizon
    template = 'total reward = {:.4f}, reward per timestep = {:.4f}\n'
    print(template.format(total, per_step))


def main():
    """Run the full pipeline: parse arguments, load the model, optimize, report.

    Prints progress messages and the elapsed time of the loading and
    optimization phases, then the final performance and the tensorboard
    command for the run's log directory.
    """
    args = parse_args()
    print_parameters(args)

    # Elapsed times use time.perf_counter(): unlike time.time() it is
    # monotonic, so system clock adjustments cannot skew the durations.
    print('>> Loading model ...')
    start = time.perf_counter()
    compiler = load_model(args)
    elapsed = time.perf_counter() - start
    print('Done in {:.6f} sec.'.format(elapsed))
    print()

    print('>> Optimizing...')
    start = time.perf_counter()
    # The returned policy is not used by this script.
    rewards, _policy, logdir = solve(compiler, args)
    elapsed = time.perf_counter() - start
    print()
    print('Done in {:.6f} sec.'.format(elapsed))
    print()

    report_performance(rewards, args.horizon)

    print('tensorboard --logdir {}\n'.format(logdir))


if __name__ == '__main__':
    main()
