#!/usr/bin/env python3

from web3 import Web3, HTTPProvider, WebsocketProvider, IPCProvider, datastructures, exceptions
from web3.middleware import geth_poa_middleware
from hexbytes import HexBytes
import argparse
import json
import logging
import os
import traceback
import getpass

# Command-line interface.  Arguments may also be read from a file by passing
# "@path" on the command line (see fromfile_prefix_chars).
parser = argparse.ArgumentParser(description='Interact with ethereum network.',
                                 fromfile_prefix_chars='@')
# NOTE: the statement below used to end with a trailing comma, which turned it
# into a discarded one-element tuple expression.
parser.add_argument('command', nargs='?',
                    help='help version exec balance nonce deploy call send receipt ...')
parser.add_argument('arguments', nargs='*',
                    help='optional arguments to the command')
parser.add_argument('-a', '--address',
                    help='contract address to send/call, defaults to env '
                    'WEB3_CONTRACT_{CONTRACT_NAME}')
parser.add_argument('-b', '--block',
                    help='defaultBlock parameter: block number or "earliest|latest|pending"')
parser.add_argument('-c', '--contract',
                    help='contract name or its json/abi path')
parser.add_argument('-d', '--dry', action='store_true',
                    help='dry run, do not transact')
parser.add_argument('-f', '--from', dest='from_account',
                    default=os.environ.get('WEB3_FROM', None),
                    help='account keystore filename, raw private key or account address, '
                    'defaults to env WEB3_FROM')
parser.add_argument('-j', '--contractJson',
                    help='contract json/abi path if different from contract name')
parser.add_argument('-p', '--provider',
                    default=os.environ.get('WEB3_PROVIDER', 'http://localhost:8545'),
                    help='web3 provider, defaults to env WEB3_PROVIDER or localhost:8545')
parser.add_argument('-t', '--to',
                    help='account address to transact to')
parser.add_argument('--chainId', type=int,
                    help='explicitly set chain id')
# The pass phrase is read from the named file, not taken from the command line.
parser.add_argument('--password', type=argparse.FileType('r'),
                    help='pass phrase to unlock the from account, defaults to empty')
parser.add_argument('--gasPrice', type=int,
                    help='explicitly set gas price')
parser.add_argument('--nonce',
                    help='explicitly set nonce')
parser.add_argument('--timeout', type=int, default=120,
                    help='timeout to wait after the transaction, default is 120s')
parser.add_argument('--value', type=int,
                    help='money amount to use in transaction')
parser.add_argument('-V', '--version', action='store_true',
                    help='show version and exit')
parser.add_argument('--poa', action='store_true',
                    help='activate web3 middleware for POA networks')

# -v and -q may each be repeated, but cannot be combined with one another.
group = parser.add_mutually_exclusive_group()
group.add_argument('-v', '--verbose', action='count',
                   help='verbose output')
group.add_argument('-q', '--quiet', action='count',
                   help='quiet output')

# Gas is either estimated or given explicitly, never both.
group = parser.add_mutually_exclusive_group()
group.add_argument('-e', '--estimate', '--estimateGas', action='store_true',
                   help='estimate gas for the tx and use it unless dry run')
group.add_argument('--gas', type=int,
                   help='explicitly set gas amount')

# Parsed at import time; the rest of the module reads this global.
args = parser.parse_args()


def getLogger(defaultLevel):
    """
    Build the module logger with a level derived from the command line.

    Every -v lowers `defaultLevel` by one logging step (10) and every -q
    raises it by one; the resulting level is never set below logging.DEBUG.
    A stream handler with a "LEVEL message" format is attached to the logger
    on every call.
    """
    level = defaultLevel - 10 * (args.verbose or 0) + 10 * (args.quiet or 0)
    handler = logging.StreamHandler()
    handler.setFormatter(logging.Formatter('%(levelname)s %(message)s'))
    logger = logging.getLogger(__name__)
    logger.setLevel(max(level, logging.DEBUG))
    logger.addHandler(handler)
    return logger


# Module-wide logger: INFO by default, shifted by the -v / -q options.
log = getLogger(logging.INFO)


# Registries filled in by the @command decorator:
#   dispatch        maps every command name or alias to its handler;
#   function_names  maps a command's primary (first) name to all of its names.
dispatch = {}
function_names = {}


def command(*names):
    """
    Decorator factory registering a handler under one or more command names.

    The first name is the primary one: it keys `function_names`, whose value
    is the full tuple of names.  Every name is added to `dispatch`.  The
    decorated function is returned unchanged.
    """
    def register(handler):
        dispatch.update(dict.fromkeys(names, handler))
        function_names[names[0]] = names
        return handler
    return register


def getJson(fname):
    """
    Load and return the JSON document stored in the file `fname`.

    The file is decoded as UTF-8 instead of the platform's locale encoding,
    so non-ASCII content is read the same way on every system.

    Raises:
        OSError: the file cannot be opened.
        ValueError: the content is not valid JSON (json.JSONDecodeError) or
            not valid UTF-8 (UnicodeDecodeError).
    """
    with open(fname, encoding='utf-8') as file:
        return json.load(file)


def getOptionalArguments(ord):
    """
    Return the positional command-line arguments from index `ord` onwards,
    evaluated as Python expressions.

    The raw strings are joined with commas, wrapped in brackets and evaluated
    as one list literal.  Returns an empty list when there are no arguments at
    or after `ord`.
    """
    if len(args.arguments) <= ord:
        return []
    raw = args.arguments[ord:]
    log.debug('Raw args %s', raw)
    # NOTE(security): eval() executes arbitrary Python taken from the command
    # line (or from an "@file" argument list).  Only feed this trusted input.
    arguments = eval('[%s]' % ','.join(raw))
    log.debug('Arguments: %s', arguments)
    return arguments


class Dymka:
    def __init__(self):
        """
        Connect to the web3 provider and resolve the sending account.

        The provider class is chosen from the --provider prefix: "ws" selects
        the websocket provider, "ipc" the IPC provider, anything else HTTP.

        --from is treated as a keystore file when it names an existing file
        and as a raw private key otherwise; when it is not accepted as a key
        it is used as a plain account address with no private key.

        A keystore is decrypted with the content of the --password file (empty
        pass phrase when absent).  If that fails, the same content without its
        trailing line break is tried, and finally the pass phrase is asked for
        interactively.  Whatever the last attempt raises propagates.
        """
        log.info(f'Using {args.provider} to connect web3...')
        if args.provider.startswith('ws'):
            self.provider = WebsocketProvider(args.provider)
        elif args.provider.startswith('ipc'):
            self.provider = IPCProvider(args.provider)
        else:
            self.provider = HTTPProvider(args.provider)
        self.w3 = Web3(self.provider)
        if args.poa:
            # inject the poa compatibility middleware to the innermost layer
            self.w3.middleware_onion.inject(geth_poa_middleware, layer=0)
        self.privKey = None
        self.address_from = None
        if args.from_account:
            if os.path.isfile(args.from_account):
                password = args.password.read() if args.password else ''
                with open(args.from_account) as keyfile:
                    encrypted_key = keyfile.read()
                # Password files commonly end with a newline that is not part
                # of the pass phrase; try the exact content first, then the
                # content without the trailing line break.
                candidates = [password]
                stripped = password.rstrip('\r\n')
                if stripped != password:
                    candidates.append(stripped)
                for candidate in candidates:
                    try:
                        self.privKey = self.w3.eth.account.decrypt(encrypted_key, candidate)
                        break
                    except ValueError:
                        continue
                else:
                    # No candidate worked: fall back to an interactive prompt.
                    self.privKey = self.w3.eth.account.decrypt(encrypted_key,
                                                               getpass.getpass())
            else:
                self.privKey = args.from_account
        if self.privKey:
            try:
                self.address_from = self.w3.eth.account.from_key(self.privKey).address
            except ValueError:
                # Not a usable private key: treat --from as an address only.
                log.info(f'Using from address {args.from_account}, no private key.')
                self.address_from = args.from_account
                self.privKey = None

    @command('show', 'config')
    def show(self):
        """Shows configuration used."""
        # 'address' is only reported when a sending account was resolved.
        status = dict(provider=args.provider)
        if self.address_from:
            status.update(address=self.address_from)
        return status

    @command('gas', 'gasprice')
    def gas(self):
        """Returns current gas price."""
        return {'gasPrice': self.w3.eth.gas_price}

    @command('private', 'priv')
    def private(self):
        """Returns private key of the 'from' account."""
        return {'address': self.address_from, 'private_key': self.privKey}

    @staticmethod
    def processAccounts(fn, *vargs):
        """
        Apply `fn` to every account from the command line and from `vargs`.

        The accounts are the positional program arguments followed by the
        truthy entries of `vargs`.  Returns a list with one
        {'account': account, 'result': fn(account)} dict per account; a falsy
        positional argument is passed through as is, without calling `fn`.
        """
        accounts = list(args.arguments)
        accounts.extend(extra for extra in vargs if extra)
        return [a and {'account': a, 'result': fn(a)} for a in accounts]

    @command('checksum', 'cs')
    def checksum(self):
        """Calculate address string representation with correct web3 checksum."""
        # Covers the positional arguments plus the --from and --to accounts.
        extra_accounts = (self.address_from, args.to)
        return self.processAccounts(Web3.to_checksum_address, *extra_accounts)

    @command('balance', 'bal')
    def balance(self):
        """Returns balance of all provided arguments and --from account."""
        # The --to account, when given, is included as well.
        query = self.w3.eth.get_balance
        extra_accounts = (self.address_from, args.to)
        return self.processAccounts(query, *extra_accounts)

    @command('nonce')
    def nonce(self):
        """Returns nonce for all provided arguments and --from account."""
        # The --to account, when given, is included as well.
        query = self.w3.eth.get_transaction_count
        extra_accounts = (self.address_from, args.to)
        return self.processAccounts(query, *extra_accounts)

    @command('chain', 'chain_id', 'id')
    def chain(self):
        """Returns chain id of the given network."""
        return self.w3.eth.chain_id

    @command('accounts', 'account', 'acc')
    def accounts(self):
        """List available accounts."""
        return self.w3.eth.accounts

    @command('keccak', 'keccak256')
    def keccak(self):
        """Calculate Keccak-256 of the given string argument."""
        # Only the first positional argument is hashed, as text.
        text = args.arguments[0]
        return Web3.keccak(text=text)

    @command('execute', 'exec', 'exe')
    def execute(self):
        """Executes given JSON-RPC command."""
        # First argument is the RPC method name; the remaining arguments are
        # evaluated and passed as the single params list.
        method = args.arguments[0]
        params = getOptionalArguments(1)
        return self.provider.make_request(method, params)

    @command('transaction', 'tx')
    def transaction(self):
        """Returns transaction with the given hash."""
        tx_hash = args.arguments[0]
        # Converted to a plain dict before being returned for JSON output.
        return dict(self.w3.eth.get_transaction(tx_hash))

    @command('block', 'bl')
    def block(self):
        """Returns current block number or specified block content."""
        if not args.arguments:
            return self.w3.eth.block_number
        # The identifier goes through getOptionalArguments, i.e. it is
        # evaluated as a Python expression (a tag such as latest needs quotes).
        block_id = getOptionalArguments(0)[0]
        return self.w3.eth.get_block(block_id)

    @staticmethod
    def getContractAddressEnv(name):
        return f'WEB3_CONTRACT_{name.upper()}'

    @classmethod
    def getContractAddress(cls):
        """
        Resolve the contract address to work with.

        --address wins when given.  Otherwise the address is read from the
        environment variable derived from --contract.

        Raises:
            ValueError: no address could be determined.
        """
        if args.address:
            log.debug(f'Using contract at {args.address}.')
            return args.address
        if not args.contract:
            raise ValueError('Need contract address, but neither --address nor --contract '
                             'is set.')
        envname = cls.getContractAddressEnv(args.contract)
        address = os.environ.get(envname)
        if not address:
            raise ValueError(f'Need contract address, but neither --address nor '
                             f'{envname} is set.')
        log.info(f'Using contract {args.contract} at {address}.')
        return address

    def getContract(self, **kwargs):
        """
        Build a web3 contract object from the --contract / --contractJson file.

        The file is --contractJson when given, otherwise "<contract>.json" if
        such a file exists, otherwise --contract itself.  Three layouts are
        recognised, tried in this order:

        1. a JSON list: a bare ABI;
        2. an object with an 'abi' key and optional 'contractName' and
           'bytecode' keys;
        3. solc "combined-json": an object whose 'contracts' section maps
           contract keys to entries with 'abi' and 'bin'.

        Args:
            **kwargs: passed through to `self.w3.eth.contract` (callers in
                this class pass `address=`).

        Raises:
            ValueError: neither --contract nor --contractJson is set, the
                file's contractName differs from --contract, the layout is not
                recognised, or the combined-json does not yield exactly one
                matching contract.
        """
        fname = args.contractJson
        if not fname:
            if not args.contract:
                # Fail with a clear message instead of a TypeError on None.
                raise ValueError('Need contract definition, but neither --contract nor '
                                 '--contractJson is set.')
            fname = args.contract + '.json'
            if not os.path.isfile(fname):
                fname = args.contract
        data = getJson(fname)

        # First guess, data is a list - it is ABI.
        if isinstance(data, list):
            log.debug(f'Using contract ABI from {fname}.')
            return self.w3.eth.contract(abi=data, **kwargs)

        if not isinstance(data, dict):
            raise ValueError(f'Unrecognised contract json format in {fname}.')

        # Second guess, data has 'contractName', 'abi' and possibly 'bytecode' keys.
        if 'abi' in data:
            log.debug(f'Using simple-json contract from {fname}.')
            if 'contractName' in data:
                contract_name = data['contractName']
                if args.contract and args.contract != contract_name:
                    raise ValueError(f'Contract name "{contract_name}" does not match '
                                     f'to "{args.contract}" that you specified.')
            if 'bytecode' in data:
                return self.w3.eth.contract(abi=data['abi'],
                                            bytecode=data['bytecode'],
                                            **kwargs)
            return self.w3.eth.contract(abi=data['abi'], **kwargs)

        # Last guess, this is "combined-json" solc format.
        log.debug(f'Using combined-json from {fname}.')
        contracts_section = data.get('contracts')
        if not isinstance(contracts_section, dict):
            raise ValueError(f'Unrecognised contract json format in {fname}.')
        # Without --contract every entry is a candidate, so the file must then
        # hold exactly one contract.
        wanted = (args.contract or '').lower()
        # Prefer an exact match on the name after the last ':' so that e.g.
        # "Token" is not ambiguous with "MyToken"; otherwise fall back to the
        # suffix match.
        exact = [key for key in contracts_section
                 if key.rsplit(':', 1)[-1].lower() == wanted]
        if len(exact) == 1:
            contracts = exact
        else:
            contracts = [key for key in contracts_section
                         if key.lower().endswith(wanted)]
        if len(contracts) != 1:
            raise ValueError(f'Ambiguous or empty contract list {contracts} in json '
                             f'{fname} for contract {args.contract}.')
        contract = contracts_section[contracts[0]]
        abi = contract['abi']
        if isinstance(abi, str):
            # Some solc versions emit the ABI as a JSON-encoded string.
            abi = json.loads(abi)
        return self.w3.eth.contract(abi=abi,
                                    bytecode=contract['bin'],
                                    **kwargs)

    @command('selector', 'sel')
    def functions(self):
        """Returns contract's function by given selector"""
        contract = self.getContract()
        if not args.arguments:
            # No selector given: list every function of the contract.
            return [str(fn) for fn in contract.all_functions()]
        selector = args.arguments[0]
        return str(contract.get_function_by_selector(selector))

    def processLogs(self, raw_logs):
        """
        Decode `raw_logs` with the events of the configured contract.

        Without --contract and --contractJson a warning is logged and the raw
        logs are returned untouched.  Otherwise each log is tried against the
        contract's events in turn; the first event that does not raise
        MismatchedABI supplies the decoded entry, and a log matching no event
        is kept in its raw form.  Returns the resulting list.
        """
        if not (args.contract or args.contractJson):
            log.warning('Unable to parse contract events, '
                        'specify --contract or --contractJson for detailed events.')
            return raw_logs
        contract = self.getContract()
        logs = []
        for event in raw_logs:
            for evt in contract.events:
                try:
                    decoded = contract.events[evt.event_name]().process_log(event)
                except exceptions.MismatchedABI:
                    continue
                logs.append(decoded)
                break
            else:
                # No event of the ABI matched this log entry.
                logs.append(event)
        return logs

    def getTxReceipt(self, hash):
        """Fetch the receipt of transaction `hash` and return it as a dict
        whose 'logs' entry has been run through processLogs."""
        raw_receipt = self.w3.eth.get_transaction_receipt(hash)
        receipt = dict(raw_receipt)
        receipt.update(logs=self.processLogs(receipt['logs']))
        return receipt

    @command('receipt')
    def receipt(self):
        """Returns transaction receipt with the given hash."""
        tx_hash = args.arguments[0]
        return self.getTxReceipt(tx_hash)

    @command('events', 'logs', 'log')
    def events(self):
        """
        Returns emitted events of the given contract for the specified block
        range (last 1000 blocks by default)
        """
        # Accepted forms of the optional first argument:
        #   N        only block N
        #   FROM-TO  blocks FROM..TO
        #   FROM-    from block FROM up to the current block
        #   -COUNT   the last COUNT blocks (clamped so it never starts below 0)
        # Anything else raises ValueError.
        if args.arguments:
            spec = args.arguments[0]
            first, dash, last = spec.partition('-')
            if '-' in last:
                raise ValueError(f'Invalid block range "{spec}", '
                                 f'expected N, FROM-TO, FROM- or -COUNT.')
            if not dash:
                fromBlock = toBlock = int(first)
            elif first:
                fromBlock = int(first)
                toBlock = int(last) if last else self.w3.eth.block_number
            else:
                toBlock = self.w3.eth.block_number
                # COUNT may exceed the chain height; never go below block 0.
                fromBlock = max(0, toBlock - int(last))
        else:
            toBlock = self.w3.eth.block_number
            fromBlock = max(0, toBlock - 1000)
        log.info(f'Requesting logs in range from {fromBlock} to {toBlock}...')
        return self.processLogs(
            self.w3.eth.get_logs({
                'address': self.getContractAddress(),
                'fromBlock': fromBlock,
                'toBlock': toBlock}))

    @command('call', 'query')
    def call(self):
        """
        Call the specified contract function with the specified parameters, if
        any.  Note, that this is not a transaction, but just a query. No
        blockchain data may change.

        """
        # First argument is the function name, the rest are its evaluated
        # parameters.
        function_name = args.arguments[0]
        contract = self.getContract(address=self.getContractAddress())
        bound = contract.functions[function_name](*getOptionalArguments(1))
        sender = self.address_from
        if not sender:
            # No --from account: query from the zero address.
            sender = '0x0000000000000000000000000000000000000000'
            log.warning(f'From account not specified, using {sender}.')
        return dict(result=bound.call({'from': sender}))

    def getOpts(self):
        """
        Prepare options for making a transaction.

        Returns a dict holding 'from' (when a sender is known), 'nonce' (the
        explicit --nonce, otherwise the sender's transaction count as reported
        by the node) and each of chainId, gas, gasPrice, value and to that was
        given a truthy value on the command line.

        Raises:
            ValueError: --nonce is neither a decimal nor a 0x-prefixed
                hexadecimal integer.
        """
        opts = {}
        if self.address_from:
            opts['from'] = self.address_from
        nonce = args.nonce
        if isinstance(nonce, str):
            # --nonce is parsed as text; convert it so the transaction carries
            # an integer like the other numeric options do.
            text = nonce.strip()
            if not text:
                nonce = None
            elif text.lower().startswith('0x'):
                nonce = int(text, 16)
            else:
                nonce = int(text)
        if nonce is not None:
            # Checked against None so that an explicit nonce of 0 is honoured.
            opts['nonce'] = nonce
        elif self.address_from:
            opts['nonce'] = self.w3.eth.get_transaction_count(self.address_from)
        for name in ['chainId', 'gas', 'gasPrice', 'value', 'to']:
            if vars(args)[name]:
                opts[name] = vars(args)[name]
        log.info('Transaction options: %s', opts)
        return opts

    def transact(self, tx):
        """
        Send the transaction `tx` and wait for its receipt.

        With --estimate the gas is estimated first and stored in `tx['gas']`.
        With --dry nothing is sent.  The transaction is signed locally and
        sent raw when a private key is available, otherwise it is handed to
        the node unsigned.

        Returns:
            dict with 'gas' (only with --estimate), 'hash' and 'receipt'
            (the last two are absent on a dry run).

        Raises:
            ValueError: no receipt arrived within --timeout seconds.  Other
                failures while waiting propagate unchanged.
        """
        status = {}
        if args.estimate:
            gas = self.w3.eth.estimate_gas(tx)
            log.info('Gas estimated %s', gas)
            tx['gas'] = status['gas'] = gas
        if args.dry:
            log.info('Dry run requested, nothing more to do')
            return status
        if self.privKey:
            signed = self.w3.eth.account.sign_transaction(tx, private_key=self.privKey)
            tx_hash = self.w3.eth.send_raw_transaction(signed.rawTransaction)
        else:
            tx_hash = self.w3.eth.send_transaction(tx)
        hash_hex = tx_hash.hex()
        status['hash'] = hash_hex
        log.info('Transaction hash: %s', hash_hex)
        try:
            self.w3.eth.wait_for_transaction_receipt(tx_hash, timeout=args.timeout)
        except exceptions.TimeExhausted as err:
            # Only a real timeout is reported as one; the cause stays attached.
            raise ValueError(f'Timeout waiting {args.timeout}s '
                             f'for transaction {hash_hex}') from err
        status['receipt'] = self.getTxReceipt(tx_hash)
        return status

    @command('deploy')
    def deploy(self):
        """
        Deploy the given contract. Pass arguments to its constructor, if
        needed. Note that full contract data JSON must be provided, not just
        ABI.

        """
        # All positional arguments are evaluated and passed to the constructor.
        contract = self.getContract()
        ctor = contract.constructor(*getOptionalArguments(0))
        tx = ctor.build_transaction(self.getOpts())
        status = self.transact(tx)
        if 'receipt' in status:
            address = status['receipt']['contractAddress']
            if args.contract:
                log.info(f'Evaluate: '
                         f'export {self.getContractAddressEnv(args.contract)}={address}')
            else:
                # Only --contractJson was given, so there is no name to build
                # the variable from; building it anyway would raise after the
                # contract is already deployed and lose the result.
                log.info(f'Contract deployed at {address}.')
        return status

    def buildContractSendTx(self, opts):
        """
        Build the transaction invoking a contract function.

        The function name is the first positional argument and its parameters
        are the remaining, evaluated ones; `opts` is passed to
        build_transaction as the transaction options.
        """
        function_name = args.arguments[0]
        target = self.getContract(address=self.getContractAddress())
        bound = target.functions[function_name](*getOptionalArguments(1))
        return bound.build_transaction(opts)

    @command('send')
    def send(self):
        """
        Invoke the specified function for the specified contract. Pass arguments,
        if needed.
        """
        opts = self.getOpts()
        if args.contract:
            tx = self.buildContractSendTx(opts)
        else:
            # Without --contract the options themselves are the transaction.
            tx = opts
        return self.transact(tx)


class EthereumEncoder(json.JSONEncoder):
    """JSON encoder that additionally serializes byte strings (through their
    hex() method) and web3 AttributeDict objects (as plain dicts)."""

    def default(self, obj):
        """Convert `obj` when it is one of the extra supported types; defer to
        the base encoder otherwise."""
        if isinstance(obj, (HexBytes, bytes)):
            return obj.hex()
        if isinstance(obj, datastructures.AttributeDict):
            return dict(obj)
        return super().default(obj)


if __name__ == "__main__":
    # Script entry point: handles the built-in "help"/"version" commands and
    # dispatches everything else to the matching Dymka handler, printing its
    # result as JSON.  Exit status is 0 on success and 1 on any error.
    #
    # SystemExit is raised directly instead of calling exit(): that helper is
    # injected by the `site` module and is not guaranteed to be available.
    command_name = (args.command or '').lower()
    if not command_name and not args.version:
        print('Please specify a command or "help".')
        raise SystemExit(0)
    if args.version or command_name == 'version':
        print(f'Version: 1.2.1, Web3.py: {str(Web3().api)}.')
        raise SystemExit(0)
    if command_name == 'help':
        if args.arguments:
            topic = args.arguments[0].lower()
            if topic not in dispatch:
                # Report the problem instead of dying with a KeyError.
                print(f'Unknown command "{topic}".')
                raise SystemExit(1)
            print(dispatch[topic].__doc__)
        else:
            print('Specify a command for help:')
            for names in function_names.values():
                print('\t' + ', '.join(names))
        raise SystemExit(0)
    if command_name not in dispatch:
        # Checked before connecting so a typo gets a readable message.
        log.error('Unknown command "%s", try "help".', args.command)
        raise SystemExit(1)
    try:
        d = Dymka()
        result = dispatch[command_name](d)
        print(json.dumps(result, cls=EthereumEncoder, indent=4))
    except Exception as e:
        # -vv (or more) shows the full traceback, otherwise just the message.
        if args.verbose and args.verbose > 1:
            traceback.print_exc()
        else:
            log.error(e)
        raise SystemExit(1)
