#!/usr/bin/python3
import lief
import xml.etree.ElementTree as ET

from argparse import ArgumentParser

from internal.asm import Assembler
from internal.ftb import FuncTable
from internal.inj import Inject
from internal.cmpl import Compile
from internal.other import *

# Command-line interface: directory of libs to hook, directory for the hooked
# output, and an optional path to the hook configuration.
parser = ArgumentParser()
parser.add_argument(
    "--xml",
    type=Path,
    dest="config",
    # None means "not given"; resolved to target_dir/hooks.xml below. Using a
    # sentinel (rather than comparing file names) keeps an explicitly passed
    # `--xml some/where/hooks.xml` from being overridden.
    default=None,
    help="Hook config file (default: target_dir/hooks.xml)",
)
parser.add_argument("target_dir", type=Path, help="Directory with libs to hook")
parser.add_argument("output_dir", type=Path, help="Directory to save hooked libs")
args = parser.parse_args()

# Validate inputs before touching the filesystem so a bad invocation does not
# leave an empty output directory behind.
if not args.target_dir.is_dir():
    raise NotFound("Target dir")

if args.config is None:
    args.config = args.target_dir / "hooks.xml"
if not args.config.is_file():
    raise NotFound("Hook config")

args.output_dir.mkdir(parents=True, exist_ok=True)

# Parse the hook configuration. ET.parse opens the path in binary mode so the
# XML encoding declaration is honoured; the tree gets its own name instead of
# shadowing the ArgumentParser above.
tree = ET.parse(args.config)
shook = tree.getroot()

# <compiler> is mandatory: fail with a clear error instead of an
# AttributeError on `.text` when it is absent or empty.
compiler = shook.find("compiler")
if compiler is None or not compiler.text or not compiler.text.strip():
    raise NotFound("Compiler")
# Strip surrounding whitespace that pretty-printed XML puts around the text.
cc = compiler.text.strip()

# For each <lib_hook> entry: load the library, compile its hooks into the
# ".shook" section, redirect each hooked function there, and save the result
# into output_dir.
for lib in shook.iterfind("lib_hook"):
    # Keep the directory part: loading just the file name would resolve
    # against the current working directory instead of target_dir.
    lib_file = args.target_dir / lib.attrib["path"]
    print(f"Patching {lib_file.name}...")

    target = lief.parse(str(lib_file))
    if target is None:
        # lief.parse returns None when the file is missing or unrecognised.
        raise NotFound(f"Library {lib_file}")

    cs_arch = lib.find("arch")
    cs_mode = lib.find("mode")
    if cs_arch is None or cs_mode is None:
        raise NotFound("Arch or mode for hooked library")
    asm = Assembler(target, cs_arch.text, cs_mode.text)
    ftb = FuncTable(target, asm)
    cmpl = Compile(cc, args.target_dir)
    inj = Inject(target, asm)

    # Headers to #include in the hook translation unit. The "include/lib"
    # path also tolerates a missing <include> element (nothing to include).
    # A distinct loop variable avoids clobbering the outer `lib`.
    for inc in lib.iterfind("include/lib"):
        kind = inc.attrib.get("kind")
        if kind == "system":
            inc_value = f"#include <{inc.text}>"
        elif kind == "local":
            inc_value = f'#include "{inc.text}"'
        else:
            raise Wrong("Kind of include lib")
        cmpl.include_lib(inc_value)

    # Functions of the target that hook code may call; their address is
    # looked up via FuncTable.load_import / load_symbol depending on `kind`.
    for inc in lib.iterfind("include/func"):
        kind = inc.attrib.get("kind")
        name = inc.text
        if kind == "import":
            addr = ftb.load_import(name)
        elif kind == "symbol":
            addr = ftb.load_symbol(name)
        else:
            raise Wrong("Kind of include func")
        proto = inc.attrib.get("proto")
        cmpl.include_func(name, proto, addr)

    # define included libs&funcs
    cmpl.assemble_transl()

    # The hook list is walked twice (compile, then patch); collect it once.
    hooks = list(lib.iterfind("hook"))

    # compile hooks
    for to_hook in hooks:
        fnc_name = to_hook.attrib["name"]
        fnc_proto = to_hook.attrib.get("proto")
        fnc_code = to_hook.text

        print(f"Compiling hook for {fnc_name}")
        cmpl.add_func_to_transl(fnc_name, fnc_proto, fnc_code)

    shook_sect = target.get_section(".shook")
    if shook_sect is None:
        raise NotFound("Section .shook")
    segm_addr = shook_sect.virtual_address
    funcs_info = cmpl.compile_transl(segm_addr)

    print("Patching the hook(s)...")
    content = []
    for func in funcs_info.values():
        # Compiled offsets are relative to the section; rebase onto its
        # virtual address before patching the code bytes.
        offset = func["offset"] + segm_addr
        content += asm.patch_sub_values(func["content"], offset)
    inj.shook_sect_fill(content)

    for to_hook in hooks:
        fnc_name = to_hook.attrib["name"]
        print(f"Hooking {fnc_name}")

        symbol = target.get_static_symbol(fnc_name)
        if symbol is None:
            raise NotFound(f"Symbol {fnc_name}")
        fnc_offset = symbol.value  # address of func
        # offset of func in created section
        payl_offset = funcs_info[fnc_name]["offset"]
        inj.hook(fnc_name, fnc_offset, payl_offset)

    # Save into output_dir under the library's file name (not into the CWD).
    target.write(str(args.output_dir / lib_file.name))
print("Lib(s) patched")
