#! /usr/bin/env python3

# This file is part of MDP-ProbLog.

# MDP-ProbLog is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# MDP-ProbLog is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with MDP-ProbLog.  If not, see <http://www.gnu.org/licenses/>.

import argparse
import os
import time

import mdpproblog
from mdpproblog.mdp import MDP
from mdpproblog.value_iteration import ValueIteration

# Bundled example models, keyed by the name accepted by the -x/--example
# option. Each entry is [model directory, domain file, instance file]; the
# script entry point joins these onto the package's 'models/' directory.
MODELS = {
	'sysadmin1': ['sysadmin', 'domain.pl', 'star2.pl'],
	'sysadmin2': ['sysadmin', 'domain2.pl', 'star2.pl'],
	'vm1': ['viralmarketing', 'domain1.pl', 'network1.pl'],
	'vm2': ['viralmarketing', 'domain2.pl', 'network1.pl'],
	'vm3': ['viralmarketing', 'domain3.pl', 'network2.pl'],
	'vm4': ['viralmarketing', 'domain4.pl', 'network2.pl'],
}

def parse():
	"""Parse the command-line arguments of the script.

	Reads the arguments from ``sys.argv`` and returns the resulting
	``argparse.Namespace`` with the attributes:

	- ``command``: one of 'list', 'show' or 'solve';
	- ``model``: list with the two values given to -m, or None;
	- ``example``: value given to -x, or None;
	- ``gamma``: discount factor (float, default 0.9);
	- ``epsilon``: maximum error (float, default 0.1).

	On invalid arguments argparse prints an error and exits the process.
	"""
	# argparse prepends "usage: " on its own, so the custom usage string must
	# not repeat it; the command names must match the `choices` below.
	usage = "mdp-problog {list,show,solve} [-m DOMAIN INSTANCE] [-x EXAMPLE] [OPTIONS]"
	description = (
		"MDP-ProbLog is a framework to represent and solve "
		"MDPs by Probabilistic Logic Programming."
	)

	parser = argparse.ArgumentParser(usage=usage, description=description)
	parser.add_argument("command", choices=['list', 'show', 'solve'], help="available commands.")
	parser.add_argument("-m", "--model", nargs=2, help="list of domain and instance files")
	parser.add_argument("-x", "--example", type=str, help="select model from examples")
	parser.add_argument("-g", "--gamma",   type=float, default=0.9, help="discount factor (default=0.9)")
	parser.add_argument("-e", "--epsilon", type=float, default=0.1, help="maximum error   (default=0.1)")
	return parser.parse_args()

def load_model(domain, instance):
	"""Read the domain and instance files.

	`domain` and `instance` are file paths. Returns a pair
	(domain text, instance text) with the full contents of each file.
	"""
	contents = []
	# The domain file is read first, then the instance file.
	for path in (domain, instance):
		with open(path, 'r') as handle:
			contents.append(handle.read())
	domain_model, instance_model = contents
	return domain_model, instance_model

def print_list_examples():
	"""Print the names of the bundled examples (keys of MODELS), sorted."""
	names = sorted(MODELS)
	print('>> Available examples: [{}]'.format(', '.join(names)))

def show_model(domain, instance):
	"""Print the contents of the domain and instance files.

	`domain` and `instance` are file paths; each file's text is printed
	under its own header and followed by a blank line.
	"""
	domain_text, instance_text = load_model(domain, instance)
	sections = (
		(":: DOMAIN ::", domain_text),
		(":: INSTANCE ::", instance_text),
	)
	for header, text in sections:
		print(header)
		print(text)
		print()

def solve_model(mdp, gamma, epsilon):
	"""Run value iteration on `mdp` and return the solver's result.

	`gamma` and `epsilon` are forwarded unchanged to ValueIteration.run.
	The script entry point unpacks the result as (V, policy, iterations).
	"""
	solver = ValueIteration(mdp)
	result = solver.run(gamma, epsilon)
	return result

def _format_state(state):
	"""Render a state, given as an iterable of (fluent, value) pairs, as 'f=v, ...'."""
	return ', '.join("{f}={v}".format(f=f, v=v) for f, v in state)

def print_solution(V, policy, iterations, uptime):
	"""Print the value function, the policy and timing statistics.

	V          -- mapping from state to numeric value
	policy     -- mapping from state to action
	iterations -- number of iterations performed
	uptime     -- total running time in seconds

	Each state is an iterable of (fluent, value) pairs.
	"""
	print()
	for state, value in V.items():
		print("Value({state}) = {value:.3f}".format(state=_format_state(state), value=value))
	print()
	for state, action in policy.items():
		print("Policy({state}) = {action}".format(state=_format_state(state), action=action))
	print()
	print(">> Value iteration converged in {time:.3f}sec after {it} iterations.".format(time=uptime, it=iterations))
	# Guard the average against zero iterations instead of raising
	# ZeroDivisionError after the results were already printed.
	average = uptime / iterations if iterations else 0.0
	print(">> Average time per iteration = {0:.3f}sec.".format(average))

if __name__ == '__main__':
	# Script entry point: parse the arguments and dispatch on the command.
	args = parse()

	if args.command == 'list':
		print_list_examples()
		# SystemExit is raised directly: the exit() helper is injected by the
		# site module and is not guaranteed to be available.
		raise SystemExit(0)

	# Resolve the domain and instance file paths, either from a bundled
	# example (-x, which takes precedence) or from the two paths given to -m.
	if args.example:
		if args.example not in MODELS:
			raise SystemExit(">> Unknown example '{}'. Use the 'list' command to see the available examples.".format(args.example))
		model_dir, domain, instance = MODELS[args.example]
		model_dir = os.path.join(os.path.dirname(mdpproblog.__file__), 'models/', model_dir)
		domain = os.path.join(model_dir, domain)
		instance = os.path.join(model_dir, instance)
	elif args.model:
		domain, instance = args.model
	else:
		# Neither -m nor -x was given: args.model is None and cannot be indexed.
		raise SystemExit(">> No model selected: use -m DOMAIN INSTANCE or -x EXAMPLE.")

	if args.command == 'show':
		show_model(domain, instance)

	if args.command == 'solve':
		domain, instance = load_model(domain, instance)
		# The full program is the domain text followed by the instance text.
		model = domain + instance
		mdp = MDP(model)
		# time.clock() was removed in Python 3.8; perf_counter() is the
		# documented replacement for measuring elapsed time.
		start = time.perf_counter()
		V, policy, iterations = solve_model(mdp, args.gamma, args.epsilon)
		uptime = time.perf_counter() - start

		print_solution(V, policy, iterations, uptime)
