#!python

from itertools import permutations, combinations
import random
from pathlib import Path
import logging
import click
import time
import sys

import qsynthesis
from qsynthesis import TritonGrammar, BvOp, InputOutputOracleLevelDB


def biased_input_generator(bitsize: int, var_num: int, input_number: int, bs: int, random_level: int = 2):
    """
    Build ``input_number`` input vectors of ``var_num`` values, biased towards corner cases.

    The candidate pool is made of 1, 0 and ``2**bitsize - 1`` followed by
    ``max(var_num - 3, random_level)`` placeholders. Vectors are sampled without
    replacement among all ``var_num``-permutations of that pool, then every
    placeholder is substituted by a fresh ``bs``-bit random integer.

    :param bitsize: bit width used to compute the all-ones corner value
    :param var_num: number of values in each vector
    :param input_number: number of vectors to return
    :param bs: bit width of the random values replacing placeholders
    :param random_level: lower bound on the number of placeholders in the pool
    :return: list of ``input_number`` vectors (each a list of ints)
    :raises ValueError: raised by ``random.sample`` when ``input_number`` is larger
                        than the number of available permutations
    """
    placeholder_count = max(var_num - 3, random_level)
    all_ones = 2 ** bitsize - 1

    pool = [1, 0, all_ones]
    pool.extend(None for _ in range(placeholder_count))

    candidates = list(permutations(pool, var_num))
    # Sampling is performed before any placeholder is resolved
    chosen = random.sample(candidates, k=input_number)

    vectors = []
    for template in chosen:
        vector = []
        for slot in template:
            vector.append(random.getrandbits(bs) if slot is None else slot)
        vectors.append(vector)
    return vectors


def operator_generator(nb_operator: int):
    """
    Generate an infinite stream of operator sets of size ``nb_operator``.

    Each round enumerates every combination of ``nb_operator`` operators among
    the supported ones, shuffles them and yields them all before starting a
    new round.

    :param nb_operator: number of operators in each yielded tuple
    :return: generator yielding tuples of ``BvOp``
    :raises ValueError: if ``nb_operator`` is greater than the number of
                        supported operators (no combination exists, which
                        would otherwise make the generator spin forever
                        without yielding), or negative (raised by
                        ``itertools.combinations``)
    """
    ops = [BvOp.NOT, BvOp.AND, BvOp.OR, BvOp.XOR, BvOp.NEG, BvOp.ADD, BvOp.MUL, BvOp.SUB]
    if nb_operator > len(ops):
        raise ValueError(f"Cannot pick {nb_operator} operators among {len(ops)} available")
    while True:
        round_combinations = list(combinations(ops, nb_operator))
        random.shuffle(round_combinations)
        yield from round_combinations


@click.group(context_settings={'help_option_names': ['-h', '--help']})
@click.version_option(version=qsynthesis.__version__, message='%(version)s')
def main():
    # Root command group: sub-commands below attach themselves to it through
    # @main.command(...). It performs no work on its own.
    pass


@main.command(name="generate")
@click.argument('output_file', type=str)
@click.option('-l', '--limit', default=0, type=int, help="Limit number of expressions to generate (-1 no limit)")
@click.option('-bs', '--bitsize', metavar="bitsize", default=64, type=int, help="Bit size of expressions")
@click.option('--var-num', default=3, type=int, help="Number of variables")
@click.option('--input-num', default=5, type=int, help="Number of inputs")
@click.option('--random-level', type=int, default=2, help="Randomness level of inputs 0 means higlhly biased to use corner-case values (0,1,-1)")
@click.option('--op-num', default=5, type=int, help="Operator number")
@click.option("-v", "--verbosity", default=0, count=True, help="increase output verbosity")
@click.option('--ops', type=str, default='', help='specifying operators to uses')
@click.option('--inputs', type=str, default='', help='specifying input vector to use')
@click.option('--watchdog', type=float, help="Activate RAM watchdog (percentage of load when to stop)")
@click.option('-c', '--cst', type=str, help="Constant to add in the generation process", multiple=True)
@click.option('--linearization', is_flag=True, type=bool, default=False, help="If set activate linearization of expressions")
def generate_command(output_file, limit, bitsize, var_num, input_num, random_level, op_num, verbosity, ops, inputs, watchdog, cst, linearization):
    """ Table generation utility """
    # NOTE: the docstring above is the help text shown by click, keep it short.
    logging.basicConfig(level=logging.DEBUG if verbosity else logging.INFO, format='%(message)s')

    # Constants are given either in hexadecimal (0x prefix) or in decimal
    constants = [int(x, 16 if x.startswith("0x") else 10) for x in cst]

    # Check optional dependencies up-front. ClickException is used (rather than
    # click.Abort) so that the explanation is actually displayed to the user.
    try:
        import pydffi
    except ImportError:
        raise click.ClickException("Cannot import dragonffi (pip3 install pydffi)")
    try:
        import sympy
    except ImportError:
        raise click.ClickException("Cannot import sympy (pip3 install sympy)")

    if bitsize not in [8, 16, 32, 64]:
        print(f"Invalid bitsize {bitsize} valid ones [8, 16, 32, 64]")
        sys.exit(1)

    # Validate user-provided operators and inputs before touching the output file
    if ops:
        try:
            ops = [BvOp[x] for x in ops.split(",")]
        except KeyError as e:
            raise click.BadParameter(f"unknown operator {e}", param_hint="--ops")
    else:
        ops = None

    # Variables are named 'a', 'b', 'c', ...
    vrs = [chr(ord('a') + x) for x in range(var_num)]

    inputs = [int(x) for x in inputs.split(",") if x]
    if inputs and vrs and len(inputs) % len(vrs):
        # A partial trailing vector would silently miss some variables
        raise click.BadParameter(
            f"expected a multiple of {len(vrs)} values (one per variable), got {len(inputs)}",
            param_hint="--inputs")

    # A pre-existing regular file at the destination is removed
    out_dir = Path(output_file)
    if out_dir.exists() and out_dir.is_file():
        out_dir.unlink()

    t1 = time.time()

    logging.info("Generate Table")

    operators = next(operator_generator(op_num)) if ops is None else ops
    if inputs:
        # Flat list of values given on the command line: split it in chunks of
        # one value per variable
        inputs = [{n: v for n, v in zip(vrs, inputs[i:i + len(vrs)])} for i in
                  range(0, len(inputs), len(vrs))]
    else:
        inputs = biased_input_generator(bitsize, var_num, input_num, bitsize, random_level)
        inputs = [{n: v for n, v in zip(vrs, i)} for i in inputs]

    grammar = TritonGrammar([(x, bitsize) for x in vrs], operators)

    logging.info(f"Watchdog value: {watchdog}")
    ltm = InputOutputOracleLevelDB.create(out_dir.absolute(), grammar, inputs, constants)

    # Watchdog arguments are only forwarded when a threshold was requested
    generate_kwargs = dict(constants=constants, linearize=linearization, limit=limit)
    if watchdog:
        generate_kwargs.update(do_watch=True, watchdog_threshold=watchdog)
    try:
        ltm.generate(bitsize, **generate_kwargs)
    except KeyboardInterrupt:
        # Ctrl-C stops the generation but still reports the elapsed time
        logging.warning("Stop required")

    elapsed = time.time() - t1
    hours, rem = divmod(elapsed, 3600)
    minutes, seconds = divmod(rem, 60)
    logging.info(f"\n{int(hours)}h{int(minutes)}m{seconds:.2f}s")


@main.command(name="info")
@click.argument('table_file', type=click.Path(exists=True))
def infos_command(table_file):
    """Getting information of a given database"""
    # NOTE: the docstring above is the help text shown by click.
    logging.basicConfig(level=logging.INFO, format='%(message)s')

    table = InputOutputOracleLevelDB.load(Path(table_file))

    logging.info(f"Bitsize: {table.bitsize}")
    logging.info(f"Size: {table.size}")
    logging.info(f"Variables: {table.grammar.vars}")
    operator_names = [op.name for op in table.grammar.ops]
    logging.info(f"Operators: {operator_names}")
    logging.info(f"Nb inputs: {len(table.inputs)}")

    # Flatten all the input vectors into a single comma-separated line
    flat_values = [str(value) for vector in table.inputs for value in vector.values()]
    logging.info(",".join(flat_values))


@main.command(name="check")
@click.argument('table_file', type=click.Path(exists=True))
def check_command(table_file):
    """Checking the equivalence of hashes against evaluation of expressions on inputs"""
    # NOTE: the docstring above is the help text shown by click.
    logging.basicConfig(level=logging.INFO, format='%(message)s')
    table_file = Path(table_file)

    table = InputOutputOracleLevelDB.load(table_file)
    count = table.size
    good, bad = 0, 0

    for i, (h, expr) in enumerate(table):
        # Refresh the progress line every 100 entries
        if i % 100 == 0:
            print(f"process {i}/{count} [KO:{bad}]\r", end="")
        # Re-evaluate the stored expression on the table inputs and compare
        # the hash of its outputs with the key it is stored under
        triton_exp = table._get_expr(expr)
        outs = table._eval_expr_inputs(triton_exp)
        if table.hash(outs) != h:
            # Count the mismatch so that the KO counter of the progress line
            # reflects reality
            bad += 1
            logging.warning(f"Bad expression: {expr}  with [{outs}]")
        else:
            good += 1
    logging.info(f"[OK:{good}/{count}]{'': <15}")


@main.command(name="compare")
@click.argument('table1', type=click.Path(exists=True))
@click.argument('table2', type=click.Path(exists=True))
def compare_command(table1, table2):
    """Compare two tables"""
    # NOTE: the docstring above is the help text shown by click.
    table1 = InputOutputOracleLevelDB.load(table1)
    table2 = InputOutputOracleLevelDB.load(table2)

    only1 = 0
    only2 = 0
    common = 0
    sz1 = table1.size
    sz2 = table2.size

    # Presence is tested with "is None" (same convention as the merge command)
    # so that a key holding an empty value still counts as present.
    for h, _ in table1:
        if table2.db.get(h) is not None:
            common += 1
        else:
            only1 += 1
    for h, _ in table2:
        if table1.db.get(h) is None:
            only2 += 1

    print(f"Table 1 size:{sz1}\tTable 2 size:{sz2}\t[Inputs:{'OK' if table1.inputs == table2.inputs else 'DIFFERENT'}]")
    print(f"Only table 1:{only1}\tOnly table2:{only2}\tCommons:{common}")
    # FUTURE: Implementing semantic comparison of common keys


@main.command(name="merge")
@click.argument('in_table', type=click.Path(exists=True))
@click.argument('out_table', type=click.Path(exists=False))
def merge_command(in_table, out_table):
    """Merge entries of the first database in the second"""
    # NOTE: the docstring above is the help text shown by click.
    source = InputOutputOracleLevelDB.load(in_table)
    target = InputOutputOracleLevelDB.load(out_table)

    # Entries are only meaningful for a given set of inputs
    if source.inputs != target.inputs:
        print("Tables should use the same set of inputs")
        sys.exit(1)

    imported = 0
    total = source.size
    for seen, (key, entry) in enumerate(source, start=1):
        # Only keys absent from the destination are copied
        if target.db.get(key) is None:
            target.add_entry(key, entry)
            imported += 1
        # Refresh the progress line every 100 entries
        if seen % 100 == 0:
            print(f"count:{seen}/{total} (imported:{imported})\r", end="")

    print(f"Imported: {imported}")


@main.command(name="dump")
@click.option('-l', '--limit', type=int, default=0, help='maximum number of entries to dump')
@click.argument('in_table', type=click.Path(exists=True))
def dump_command(limit, in_table):
    """Dump the content of the table on stdout"""
    # NOTE: the docstring above is the help text shown by click.
    lkp_in = InputOutputOracleLevelDB.load(in_table)

    for counter, (key, s) in enumerate(lkp_in):
        # A limit <= 0 means "no limit"; otherwise stop once exactly `limit`
        # entries have been printed (the former test let one extra through)
        if 0 < limit <= counter:
            break
        print(f"{key} -> {s}")


if __name__ == "__main__":
    # Script entry point: hand control over to the click command group
    main()
