#!/usr/bin/env python3

# This file is part of pyrddl.

# pyrddl is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# pyrddl is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with pyrddl. If not, see <http://www.gnu.org/licenses/>.


from pyrddl.parser import RDDLParser

import argparse


def parse_args(argv=None):
    """Parse the command-line arguments of the RDDL printing script.

    Args:
        argv: optional list of argument strings to parse instead of the
            process arguments. When None (the default), argparse reads
            the arguments from sys.argv[1:].

    Returns:
        The argparse namespace, with `rddl` (the RDDL filepath string)
        and `verbose` (bool, False unless -v/--verbose is given).
    """
    description = 'RDDL lexer/parser in Python3.'
    parser = argparse.ArgumentParser(description=description)
    parser.add_argument('rddl', type=str, help='RDDL filepath')
    parser.add_argument(
        '-v', '--verbose',
        action='store_true',
        help='verbosity mode'
    )
    # Forwarding argv lets callers (and tests) drive the parser without
    # touching the global sys.argv.
    return parser.parse_args(argv)


def read_file(path, encoding='utf-8'):
    """Return the whole content of the text file at `path` as a string.

    Args:
        path: filepath to open for reading.
        encoding: text encoding used to decode the file. It is passed
            explicitly so the result does not depend on the platform's
            locale default.

    Raises:
        OSError: if the file cannot be opened.
        UnicodeDecodeError: if the content is not valid in `encoding`.
    """
    with open(path, 'r', encoding=encoding) as f:
        return f.read()


def print_block_header(block, name):
    """Print a three-line banner announcing an RDDL block.

    The banner is a rule of asterisks, a title line of the form
    ' RDDL <block> = <name>', and the same rule again.
    """
    rule = '*********************************************************'
    title = ' RDDL {} = {}'.format(block, name)
    for line in (rule, title, rule):
        print(line)


def print_types(param_types):
    """Print the given type names on one line, separated by commas.

    `param_types` must be an iterable of strings, since it is joined
    with str.join.
    """
    joined = ','.join(param_types)
    print('>> types: {}'.format(joined))


def print_pvariables(pvar_type, pvariables):
    """Print a '>> <pvar_type>:' section listing each pvariable.

    `pvariables` is a mapping from name to pvariable object. For every
    entry one line is printed with the name and range, followed by the
    `level` when `is_intermediate_fluent()` is true and by the `default`
    otherwise. The section ends with a blank line.
    """
    print('>> {}:'.format(pvar_type))
    for name, pvar in pvariables.items():
        # Choose the line template and its fields first, then print once.
        if pvar.is_intermediate_fluent():
            template = '    - {} : range={}, level={}'
            fields = (name, pvar.range, pvar.level)
        else:
            template = '    - {} : range={}, default={}'
            fields = (name, pvar.range, pvar.default)
        print(template.format(*fields))
    print()


def print_cpfs(fluent_type, cpfs):
    """Print a '>> <fluent_type>:' section with one entry per CPF.

    For each cpf, the name is read from `cpf.pvar[1][0]` and the
    argument list from `cpf.pvar[1][1]` (None is shown as no arguments).
    The entry is the line '<name>(<args>) = ', then `cpf.expr`, then a
    blank line.
    """
    print('>> {}:'.format(fluent_type))
    for cpf in cpfs:
        signature = cpf.pvar[1]
        name = signature[0]
        params = signature[1]
        args = ', '.join(params) if params is not None else ''
        print('{}({}) = '.format(name, args))
        print(cpf.expr)
        print()


def print_reward(r):
    """Print the '>> reward:' heading and `r`, each followed by a blank line."""
    # end='\n\n' emits the line terminator plus the blank separator line.
    print('>> reward:', end='\n\n')
    print(r, end='\n\n')


def print_constraints(constraint_type, constraints):
    """Print a '>> <constraint_type>:' section listing each constraint.

    The heading and every constraint are each followed by a blank line.
    """
    # end='\n\n' emits the line terminator plus the blank separator line.
    print('>> {}:'.format(constraint_type), end='\n\n')
    for constraint in constraints:
        print(constraint, end='\n\n')


def print_initializer(initializer_type, initializers):
    """Print a '>> <initializer_type>:' section with one line per initializer.

    Each initializer is indexed as ((name, args), value); `args` may be
    None, which is shown as an empty argument list. The printed line has
    the form '<name>(<args>) = <value>'.
    """
    print('>> {}:'.format(initializer_type))
    for initializer in initializers:
        target = initializer[0]
        name = target[0]
        params = target[1]
        args = ', '.join(params) if params is not None else ''
        value = initializer[1]
        print('{}({}) = {}'.format(name, args, value))


def print_domain(d, verbose=False):
    """Print the domain block header and, if `verbose`, its contents.

    The header and a blank line are always printed. In verbose mode the
    following are printed in order: the type names (first item of each
    entry in `d.types`, only when `d.types` is truthy), the four
    pvariable groups, the state and intermediate CPFs, the reward, and
    each non-empty group among preconds, constraints and invariants.
    """
    print_block_header('domain', d.name)
    print()

    if not verbose:
        return

    # types
    declared_types = d.types
    if declared_types:
        print_types([entry[0] for entry in declared_types])

    # pvariables: (printed label, attribute name on the domain)
    pvariable_groups = (
        ('non-fluents', 'non_fluents'),
        ('state-fluents', 'state_fluents'),
        ('intermediate-fluents', 'intermediate_fluents'),
        ('action-fluents', 'action_fluents'),
    )
    for label, attribute in pvariable_groups:
        print_pvariables(label, getattr(d, attribute))

    # cpfs: the printed label is also the attribute name
    for section in ('state_cpfs', 'intermediate_cpfs'):
        print_cpfs(section, getattr(d, section))

    # reward
    print_reward(d.reward)

    # constraints: the printed label is also the attribute name;
    # empty groups are skipped
    for section in ('preconds', 'constraints', 'invariants'):
        group = getattr(d, section)
        if group:
            print_constraints(section, group)


def print_non_fluents(nf, verbose=False):
    """Print the non-fluents block header and, if `verbose`, its contents.

    In verbose mode the objects are listed (one line per
    (type, objects) pair in `nf.objects`), followed by the non-fluent
    initializers from `nf.init_non_fluent`. A trailing blank line is
    printed in both modes.
    """
    print_block_header('non-fluents', nf.name)

    if not verbose:
        print()
        return

    # objects
    print('>> objects:')
    for type_name, members in nf.objects:
        listing = ','.join(members)
        print('     - {}: {}'.format(type_name, listing))
    print()

    # non-fluent initializers
    print_initializer('non-fluents', nf.init_non_fluent)

    print()


def print_instance(inst, verbose=False):
    """Print the instance block header and, if `verbose`, its contents.

    In verbose mode the horizon, max_nondef_actions and discount are
    printed, followed by a blank line and the initial-state
    initializers from `inst.init_state`. A trailing blank line is
    printed in both modes.
    """
    print_block_header('instance', inst.name)

    if not verbose:
        print()
        return

    horizon_line = 'horizon = {}'.format(inst.horizon)
    actions_line = 'max_nondef_actions = {}'.format(inst.max_nondef_actions)
    discount_line = 'discount = {}'.format(inst.discount)
    for line in (horizon_line, actions_line, discount_line):
        print(line)
    print()

    print_initializer('init-state', inst.init_state)

    print()


if __name__ == '__main__':

    # parse the CLI arguments
    cli_args = parse_args()
    show_details = cli_args.verbose

    # load the RDDL source text from the given filepath
    source_text = read_file(cli_args.rddl)

    # build the parser
    rddl_parser = RDDLParser()
    rddl_parser.build()

    # parse the source text into the RDDL model
    model = rddl_parser.parse(source_text)

    # print each RDDL block in turn
    print_domain(model.domain, show_details)
    print_non_fluents(model.non_fluents, show_details)
    print_instance(model.instance, show_details)
