#!/usr/bin/env python3
import argparse
from collections import (
    OrderedDict,
)
import json
import os
import sys
from typing import (
    Any,
    Iterable,
    Iterator,
    Set,
    TypeVar,
)
import warnings

import vyper
from vyper.parser import (
    parser_utils,
)
from vyper.signatures.interface import (
    extract_file_interface_imports,
)
from vyper.typing import (
    ContractCodes,
    ContractName,
)

# Emit every warning each time it is triggered, rather than only the first
# occurrence per source location.
warnings.simplefilter('always')

# Help text shown for the ``-f`` option; lists the accepted output format names.
# Kept as a pre-formatted block: each entry is one line with a leading space.
format_options_help = """Format to print, one or more of:
 bytecode (default) - Deployable bytecode
 bytecode_runtime   - Bytecode at runtime
 abi                - ABI in JSON format
 abi_python         - ABI in python format
 ast                - AST in JSON format
 source_map         - Vyper source map
 method_identifiers - Dictionary of method signature to method identifier.
 combined_json      - All of the above format options combined as single JSON output.
 interface          - Print Vyper interface of a contract
 external_interface - Print External Contract of a contract, to be used as outside contract calls.
 opcodes            - List of opcodes as a string
 opcodes_runtime    - List of runtime opcodes as a string
"""

# Command-line interface definition. RawTextHelpFormatter keeps the line
# breaks of the multi-line ``-f`` help text intact.
parser = argparse.ArgumentParser(
    formatter_class=argparse.RawTextHelpFormatter,
    description='Vyper programming language for Ethereum',
)

# Positional: one or more source files.
parser.add_argument('input_files', nargs='+', help='Vyper sourcecode to compile')

# Prints "<version>+commit.<commit>" and exits.
parser.add_argument(
    '--version',
    version='{0}+commit.{1}'.format(vyper.__version__, vyper.__commit__),
    action='version',
)

parser.add_argument(
    '--show-gas-estimates',
    action="store_true",
    help='Show gas estimates in ir output mode.',
)

# Requested output format(s); the value is stored as ``args.format``.
parser.add_argument('-f', dest='format', default='bytecode', help=format_options_help)

parser.add_argument('--debug', action='store_true', help='Add compiler debugging output')

# Arguments are parsed as soon as this module is executed or imported.
args = parser.parse_args()

# Traceback verbosity: a non-empty VYPER_TRACEBACK_LIMIT environment variable
# takes precedence; without it, tracebacks are hidden unless --debug was given.
tb_limit = os.getenv('VYPER_TRACEBACK_LIMIT')
if tb_limit:
    sys.tracebacklimit = int(tb_limit)
elif not args.debug:
    # Left unset, sys.tracebacklimit would behave as Python's default of 1000.
    sys.tracebacklimit = 0


# Element type that ``uniq`` passes through unchanged.
T = TypeVar('T')


def uniq(seq: Iterable[T]) -> Iterator[T]:
    """
    Lazily produce the items of ``seq`` with duplicates dropped.

    The first occurrence of each item is kept, so relative order is
    preserved. Items are tracked in a set and must therefore be hashable.
    """
    emitted: Set[T] = set()

    for item in seq:
        if item not in emitted:
            emitted.add(item)
            yield item


def exc_handler(contract_name: ContractName, exception: Exception) -> None:
    """
    Report which contract failed to compile, then propagate the error.

    The contract name is written to standard output before ``exception`` is
    re-raised, so this function never returns normally.
    """
    label = 'Error compiling: '
    print(label, contract_name)
    raise exception


def get_interface_codes(codes: ContractCodes) -> Any:
    """
    Load the source of every interface imported by the given contracts.

    Imports are collected from each contract with
    ``extract_file_interface_imports``. An import path is turned into a file
    path by replacing dots with slashes; the file is then looked up with a
    ``.vy`` extension first and a ``.json`` extension second.

    :param codes: mapping whose values are contract source strings
    :return: dict of interface name -> ``{'type': ..., 'code': ...}`` where
             ``type`` is ``'vyper'`` (``code`` is the raw file text) or
             ``'json'`` (``code`` is the parsed JSON document)
    :raises Exception: if neither candidate file exists for an import
    """
    imports = {}
    for source in codes.values():
        imports.update(extract_file_interface_imports(source))

    interfaces = {}
    for interface_name, interface_path in imports.items():
        base_path = os.path.normpath(interface_path.replace('.', '/'))

        # The first existing candidate wins, so ``.vy`` has priority over ``.json``.
        for suffix in ('.vy', '.json'):
            valid_path = base_path + suffix
            if os.path.exists(valid_path):
                break
        else:
            raise Exception(
                f'Imported interface "{interface_path}{{.vy,.json}}" does not exist.'
            )

        with open(valid_path) as fh:
            contents = fh.read()

        if valid_path.endswith('.json'):
            # json.loads is given the UTF-8 encoded bytes of the file contents.
            entry = {'type': 'json', 'code': json.loads(contents.encode())}
        else:
            entry = {'type': 'vyper', 'code': contents}
        interfaces[interface_name] = entry

    return interfaces


if __name__ == '__main__':
    # NOTE(review): presumably makes LLL node reprs include gas estimates —
    # confirm against parser_utils.LLLnode.
    if args.show_gas_estimates:
        parser_utils.LLLnode.repr_show_gas = True

    # Read every input file up front, keyed by the path given on the command line.
    codes: ContractCodes = OrderedDict()
    for file_name in args.input_files:
        with open(file_name) as fh:
            codes[file_name] = fh.read()

    # Both output modes compile with the same set of imported interfaces.
    interface_codes = get_interface_codes(codes)

    if args.format == 'combined_json':
        # A single JSON document with a fixed set of outputs, plus the compiler version.
        combined_formats = [
            'bytecode',
            'bytecode_runtime',
            'abi',
            'source_map',
            'method_identifiers',
        ]
        out = vyper.compile_codes(
            codes,
            combined_formats,
            exc_handler=exc_handler,
            interface_codes=interface_codes,
        )
        out['version'] = vyper.__version__
        print(json.dumps(out))

    else:
        # Command-line names whose compiler output key differs from the name itself.
        translate_map = {'abi_python': 'abi', 'json': 'abi', 'ast': 'ast_dict'}
        # Formats serialised with json.dumps; every other format is printed as-is.
        json_formats = ('abi', 'json', 'ast', 'source_map')

        # Requested formats, de-duplicated but kept in the order given.
        orig_args = tuple(uniq(args.format.split(',')))
        formats = [translate_map.get(name, name) for name in orig_args]

        compiled = vyper.compile_codes(
            codes,
            formats,
            exc_handler=exc_handler,
            interface_codes=interface_codes,
        )

        # For each contract, one print call per requested format, in request order.
        for contract_output in compiled.values():
            for name in orig_args:
                value = contract_output[translate_map.get(name, name)]
                print(json.dumps(value) if name in json_formats else value)
