#!/usr/bin/env python3

from itertools import permutations, combinations
import random
from pathlib import Path
import logging
import click
import time
import sys

import qsynthesis
from qsynthesis import TritonGrammar, BvOp, HashType, InputOutputOracleLevelDB


def _count_permutations(m: int, r: int) -> int:
    """Return the number of ``r``-permutations of ``m`` items, i.e. ``m! / (m - r)!`` (0 if ``r > m``)."""
    total = 1
    # When r > m the range contains 0, which correctly yields 0.
    for factor in range(m - r + 1, m + 1):
        total *= factor
    return total


def _nth_permutation(pool: list, r: int, index: int) -> list:
    """Return the ``index``-th (0-based) tuple of ``itertools.permutations(pool, r)`` as a list."""
    remaining = list(pool)
    chosen = []
    for depth in range(r):
        # Every choice at this depth is followed by this many completions, and
        # itertools.permutations emits them in lexicographic order of positions.
        span = _count_permutations(len(remaining) - 1, r - depth - 1)
        pos, index = divmod(index, span)
        chosen.append(remaining.pop(pos))
    return chosen


def biased_input_generator(bitsize: int, var_num: int, input_number: int, bs: int, random_level: int = 2):
    """Generate ``input_number`` input vectors of ``var_num`` values biased toward corner cases.

    The candidate pool is ``[1, 0, 2**bitsize - 1] + [None] * n`` with
    ``n = max(var_num - 3, random_level)``. ``input_number`` distinct indices are
    drawn from the sequence of ``var_num``-permutations of that pool. The pool
    positions are treated as distinct, so patterns that differ only by swapped
    ``None`` slots count separately. Every ``None`` slot is then replaced by
    ``random.getrandbits(bs)``. A larger ``random_level`` therefore yields more
    random values.

    The permutations are never materialised. ``random.sample`` picks indices
    into the virtual ``itertools.permutations(pool, var_num)`` sequence, and each
    index is decoded directly. For a given random state the result is the same
    as sampling from ``list(permutations(pool, var_num))``, but memory stays
    proportional to the output.

    :param bitsize: bit width used for the ``-1`` corner value (``2**bitsize - 1``)
    :param var_num: number of values per input vector
    :param input_number: number of vectors to draw
    :param bs: bit width of the random values
    :param random_level: minimum number of random slots in the pool
    :returns: list of ``input_number`` lists of ``var_num`` ints
    :raises ValueError: if ``var_num`` is negative, or ``input_number`` is negative
        or larger than the number of permutations
    """
    if var_num < 0:
        raise ValueError("var_num must be non-negative")
    n = max(var_num - 3, random_level)
    minus_one = pow(2, bitsize) - 1
    pool = [1, 0, minus_one] + [None] * n

    # random.sample's selection depends only on the population length and k, so
    # sampling a range selects the same positions as sampling the full list.
    indices = random.sample(range(_count_permutations(len(pool), var_num)), k=input_number)
    return [[random.getrandbits(bs) if value is None else value
             for value in _nth_permutation(pool, var_num, idx)]
            for idx in indices]


def operator_generator(nb_operator: int, ops=None):
    """Yield an endless stream of operator tuples of size ``nb_operator``.

    The stream is made of successive rounds. Each round yields every
    ``nb_operator``-combination of ``ops`` exactly once, in a freshly shuffled
    order, so all combinations are produced before any of them repeats.

    :param nb_operator: number of operators in each yielded tuple
    :param ops: candidate operators; defaults to NOT, AND, OR, XOR, NEG, ADD, MUL, SUB
    :raises ValueError: on the first ``next()`` if no combination of that size
        exists (e.g. ``nb_operator`` larger than the number of candidates); without
        this check the generator would loop forever without yielding
    """
    if ops is None:
        ops = [BvOp.NOT, BvOp.AND, BvOp.OR, BvOp.XOR, BvOp.NEG, BvOp.ADD, BvOp.MUL, BvOp.SUB]
    ops = list(ops)
    all_combos = list(combinations(ops, nb_operator))
    if not all_combos:
        raise ValueError(f"cannot pick {nb_operator} operators out of {len(ops)}")
    while True:
        # Shuffle a copy so each round starts from the same canonical order.
        round_combos = list(all_combos)
        random.shuffle(round_combos)
        yield from round_combos


@click.group(context_settings={'help_option_names': ['-h', '--help']})
@click.version_option(version=qsynthesis.__version__, message='%(version)s')
def main():
    """Generate, inspect, check, compare, merge and dump QSynthesis lookup tables."""
    # Command group entry point: click dispatches to the sub-commands registered
    # with @main.command; the docstring above is the group's --help description.


def _parse_int(text: str) -> int:
    """Parse a decimal or ``0x``-prefixed hexadecimal integer, ignoring surrounding whitespace.

    :raises ValueError: if ``text`` is not a valid integer literal
    """
    text = text.strip()
    return int(text, 16 if text.startswith("0x") else 10)


@main.command(name="generate")
@click.argument('output_file', type=str)
@click.option('-l', '--limit', default=0, type=int, help="Limit number of expressions to generate (-1 no limit)")
@click.option('-bs', '--bitsize', metavar="bitsize", default=64, type=int, help="Bit size of expressions")
@click.option('--var-num', default=3, type=click.IntRange(min=1), help="Number of variables")
@click.option('--input-num', default=5, type=int, help="Number of inputs")
@click.option('--random-level', type=int, default=2, help="Randomness level of inputs 0 means highly biased to use corner-case values (0,1,-1)")
@click.option('--op-num', default=5, type=int, help="Operator number")
@click.option("-v", "--verbosity", default=0, count=True, help="increase output verbosity")
@click.option('--ops', type=str, default='', help='specifying operators to use')
@click.option('--inputs', type=str, default='', help='specifying input vector to use')
@click.option('--hash-mode', default=HashType.MD5.name, type=click.Choice([x.name for x in HashType]), help="Hash function for keys in table")
@click.option('--watchdog', type=float, help="Activate RAM watchdog (percentage of load when to stop)")
@click.option('-c', '--cst', type=str, help="Constant to add in the generation process", multiple=True)
@click.option('--linearization', is_flag=True, type=bool, default=False, help="If set activate linearization of expressions")
def generate_command(output_file, limit, bitsize, var_num, input_num, random_level, op_num, verbosity, ops, inputs, hash_mode, watchdog, cst, linearization):
    """ Table generation utility """
    # Flow: parse and validate every option first, and only then delete a previous
    # table file at OUTPUT_FILE, so an invalid option never destroys an existing
    # table. Then create the table and run the generation. --ops and --inputs are
    # comma-separated lists; integers (--inputs, -c) may be decimal or 0x-hex.
    logging.basicConfig(level=logging.DEBUG if verbosity else logging.INFO, format='%(message)s')

    try:
        constants = [_parse_int(x) for x in cst]
    except ValueError as err:
        raise click.BadParameter(str(err), param_hint="'-c' / '--cst'") from None

    # Availability checks only: the modules are not used in this function.
    # NOTE(review): presumably required by the generation backend — confirm.
    try:
        import pydffi  # noqa: F401
    except ImportError:
        # click.Abort would only print "Aborted!"; ClickException shows the message.
        raise click.ClickException("Cannot import dragonffi (pip3 install pydffi)") from None
    try:
        import sympy  # noqa: F401
    except ImportError:
        raise click.ClickException("Cannot import sympy (pip3 install sympy)") from None

    if bitsize not in [8, 16, 32, 64]:
        print(f"Invalid bitsize {bitsize} valid ones [8, 16, 32, 64]")
        sys.exit(1)

    # Operators: the explicit --ops names, otherwise a random combination of --op-num operators.
    if ops:
        try:
            operators = [BvOp[name.strip()] for name in ops.split(",")]
        except KeyError as err:
            valid = ",".join(op.name for op in BvOp)
            raise click.BadParameter(f"unknown operator {err} (valid: {valid})", param_hint="'--ops'") from None
    else:
        operators = next(operator_generator(op_num))

    # One variable per --var-num, named a, b, c, ...
    vrs = [chr(ord('a') + x) for x in range(var_num)]
    try:
        flat_inputs = [_parse_int(x) for x in inputs.split(",") if x.strip()]
    except ValueError as err:
        raise click.BadParameter(str(err), param_hint="'--inputs'") from None
    if flat_inputs:
        # --inputs is flat: each consecutive group of var_num values is one input vector.
        # A trailing partial group would silently produce an input missing variables.
        if len(flat_inputs) % var_num:
            raise click.BadParameter(
                f"{len(flat_inputs)} values given, expected a multiple of --var-num ({var_num})",
                param_hint="'--inputs'")
        input_vectors = [dict(zip(vrs, flat_inputs[i:i + var_num]))
                         for i in range(0, len(flat_inputs), var_num)]
    else:
        samples = biased_input_generator(bitsize, var_num, input_num, bitsize, random_level)
        input_vectors = [dict(zip(vrs, sample)) for sample in samples]

    out_dir = Path(output_file)
    if out_dir.is_file():
        out_dir.unlink()

    t1 = time.time()
    logging.info("Generate Table")

    grammar = TritonGrammar([(x, bitsize) for x in vrs], operators)

    logging.info(f"Watchdog value: {watchdog}")
    ltm = InputOutputOracleLevelDB.create(out_dir.absolute(), grammar, input_vectors, HashType[hash_mode], constants)
    gen_kwargs = dict(constants=constants, linearize=linearization, limit=limit)
    if watchdog:
        gen_kwargs.update(do_watch=True, watchdog_threshold=watchdog)
    try:
        ltm.generate(bitsize, **gen_kwargs)
    except KeyboardInterrupt:
        # Ctrl-C stops the generation; the elapsed time is still reported below.
        logging.warning("Stop required")

    elapsed = time.time() - t1
    hours, rem = divmod(elapsed, 3600)
    minutes, seconds = divmod(rem, 60)
    logging.info(f"\n{int(hours)}h{int(minutes)}m{seconds:.2f}s")


@main.command(name="info")
@click.argument('table_file', type=click.Path(exists=True))
def infos_command(table_file):
    """Getting information of a given database"""
    # Log the table metadata (bitsize, hash mode, size, variables, operators,
    # number of inputs), then all input values flattened onto one line.
    logging.basicConfig(level=logging.INFO, format='%(message)s')
    db_table = InputOutputOracleLevelDB.load(Path(table_file))

    logging.info(f"Bitsize: {db_table.bitsize}")
    logging.info(f"Hash mode: {db_table.hash_mode.name}")
    logging.info(f"Size: {db_table.size}")
    logging.info(f"Variables: {db_table.grammar.vars}")
    op_names = [op.name for op in db_table.grammar.ops]
    logging.info(f"Operators: {op_names}")
    logging.info(f"Nb inputs: {len(db_table.inputs)}")

    # Input order first, then each input mapping's own value order.
    flat_values = [str(value) for vector in db_table.inputs for value in vector.values()]
    logging.info(",".join(flat_values))


@main.command(name="check")
@click.argument('table_file', type=click.Path(exists=True))
def check_command(table_file):
    """Checking the equivalence of hashes against evaluation of expressions on inputs"""
    # Re-evaluate every stored expression on the table's inputs and check that the
    # hash of the outputs equals the key the expression is stored under.
    # Mismatches are logged as warnings and counted as KO.
    logging.basicConfig(level=logging.INFO, format='%(message)s')
    table_file = Path(table_file)

    table = InputOutputOracleLevelDB.load(table_file)
    count = table.size
    good, bad = 0, 0

    for i, (h, expr) in enumerate(table):
        if i % 100 == 0:
            # '\r' with no newline: flush, or a line-buffered terminal would hold it back.
            print(f"process {i}/{count} [KO:{bad}]\r", end="", flush=True)
        triton_exp = table.get_expr(expr)
        outs = table.eval_expr_inputs(triton_exp)
        if table.hash(outs) != h:
            bad += 1
            logging.warning(f"Bad expression: {expr}  with [{outs}]")
        else:
            good += 1
    # Trailing padding overwrites the remains of the last progress line.
    logging.info(f"[OK:{good}/{count}] [KO:{bad}]{'': <15}")


@main.command(name="compare")
@click.argument('table1', type=click.Path(exists=True))
@click.argument('table2', type=click.Path(exists=True))
def compare_command(table1, table2):
    """Compare two tables"""
    # Compare the key (hash) sets of two tables. Report how many keys are exclusive
    # to each table and how many are shared, and whether both use the same inputs.
    first = InputOutputOracleLevelDB.load(table1)
    second = InputOutputOracleLevelDB.load(table2)

    only1 = 0
    only2 = 0
    common = 0
    sz1 = first.size
    sz2 = second.size
    # Presence is tested with `is None` rather than truthiness, so an entry whose
    # stored value is empty still counts as present. This matches the merge command.
    # Assumes db.get returns None for a missing key — TODO confirm against the DB backend.
    for h, _ in first:
        if second.db.get(h) is not None:
            common += 1
        else:
            only1 += 1
    for h, _ in second:
        if first.db.get(h) is None:
            only2 += 1

    print(f"Table 1 size:{sz1}\tTable 2 size:{sz2}\t[Inputs:{'OK' if first.inputs == second.inputs else 'DIFFERENT'}]")
    print(f"Only table 1:{only1}\tOnly table2:{only2}\tCommons:{common}")
    # FUTURE: Implementing semantic comparison of common keys


@main.command(name="merge")
@click.argument('in_table', type=click.Path(exists=True))
@click.argument('out_table', type=click.Path(exists=False))
def merge_command(in_table, out_table):
    """Merge entries of the first database in the second"""
    # Copy every entry of IN_TABLE whose key is absent from OUT_TABLE into OUT_TABLE.
    # Both tables must have been built on the same inputs, otherwise exit with status 1.
    # NOTE(review): OUT_TABLE is not required to exist by click, yet it is opened with
    # load(); confirm load() handles a missing table.
    lkp_in = InputOutputOracleLevelDB.load(in_table)
    lkp_out = InputOutputOracleLevelDB.load(out_table)

    if lkp_in.inputs != lkp_out.inputs:
        print("Tables should use the same set of inputs")
        sys.exit(1)

    imported = 0
    sz = lkp_in.size
    progress = ""
    for count, (key, expr) in enumerate(lkp_in, start=1):
        if lkp_out.db.get(key) is None:
            lkp_out.add_entry(key, expr)
            imported += 1
        if count % 100 == 0:
            progress = f"count:{count}/{sz} (imported:{imported})"
            # '\r' with no newline: flush, or a line-buffered terminal would hold it back.
            print(f"{progress}\r", end="", flush=True)

    # Pad so the summary fully overwrites the (longer) last progress line.
    print(f"Imported: {imported}".ljust(len(progress)))


@main.command(name="dump")
@click.option('-l', '--limit', type=int, default=0, help='maximum number of entries to dump')
@click.argument('in_table', type=click.Path(exists=True))
def dump_command(limit, in_table):
    """Dump the content of the table on stdout"""
    # Print one "<hash> -> <entry>" line per table entry, in iteration order.
    # A limit <= 0 means no limit.
    lkp_in = InputOutputOracleLevelDB.load(in_table)

    for counter, (key, entry) in enumerate(lkp_in):
        # `counter` entries have been printed so far; stop once that reaches the limit.
        if 0 < limit <= counter:
            break
        print(f"{key} -> {entry}")


if __name__ == "__main__":
    # Script entry point: hand control to the click command group defined above.
    main()
