#!/usr/bin/env python3
#
# A simple implementation of COMET emulator.
# Copyright (c) 2021, Hiroyuki Ohsaki.
# All rights reserved.
#

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.

import collections
import re
import struct
import sys

from perlcompat import die, warn, getopts
import tbdump

VERSION = 0.2
DEBUG = False  # set to True by the -d option; enables debug() output

# fixed addresses of the IN/OUT/EXIT system calls
SYS_IN, SYS_OUT, SYS_EXIT = 0xfff0, 0xfff2, 0xfff4

# possible values of the flag register FR
FR_PLUS, FR_ZERO, FR_MINUS = range(3)

# initial value of GR4 (the stack pointer); PUSH/CALL move it downward,
# so this is the upper limit of the stack space.
STACK_TOP = 0xff00

# range of a signed 16-bit word
MAX_SIGNED = 0x7fff
MIN_SIGNED = -0x8000

# opcode-table entry: operand format ('op1'..'op4') and mnemonic
Inst = collections.namedtuple('Inst', 'type nemonic')
# debugger command-table entry: command-name regexp, handler function, and
# whether the register state is printed after the handler runs
Command = collections.namedtuple('Command', 'regexp subr need_print')

# COMET instructions
# Maps each opcode (the high byte of an instruction's first word, see
# Comet.decode) to an Inst(type, nemonic) record.  TYPE selects the operand
# format used by Comet.parse:
#   'op1' -- GR, adr[, XR]   (2 words)
#   'op2' -- adr[, XR]       (2 words)
#   'op3' -- GR only         (2 words)
#   'op4' -- no operand      (1 word)
INSTTBL = {
    0x10: Inst('op1', 'LD'),
    0x11: Inst('op1', 'ST'),
    0x12: Inst('op1', 'LEA'),
    0x20: Inst('op1', 'ADD'),
    0x21: Inst('op1', 'SUB'),
    0x30: Inst('op1', 'AND'),
    0x31: Inst('op1', 'OR'),
    0x33: Inst('op1', 'EOR'),
    0x40: Inst('op1', 'CPA'),
    0x41: Inst('op1', 'CPL'),
    0x50: Inst('op1', 'SLA'),
    0x51: Inst('op1', 'SRA'),
    0x52: Inst('op1', 'SLL'),
    0x53: Inst('op1', 'SRL'),
    0x60: Inst('op2', 'JPZ'),
    0x61: Inst('op2', 'JMI'),
    0x62: Inst('op2', 'JNZ'),
    0x63: Inst('op2', 'JZE'),
    0x64: Inst('op2', 'JMP'),
    0x70: Inst('op2', 'PUSH'),
    0x71: Inst('op3', 'POP'),
    0x80: Inst('op2', 'CALL'),
    0x81: Inst('op4', 'RET'),
}

def usage():
    """Report the command-line synopsis through die()."""
    die(f"""\
usage: {sys.argv[0]} [-q] [-d] [com-file]
  -q    do not print the startup banner
  -d    print debug messages
""")

def debug(msg):
    """Pass MSG, prefixed with '** ', to warn() -- only when DEBUG is on."""
    if not DEBUG:
        return
    warn('** ' + msg)

def signed(val):
    """Interpret the low 16 bits of VAL as a two's-complement signed word.

    Any int is accepted: it is reduced modulo 2**16 first, so a value that
    escaped 16-bit masking no longer raises struct.error.
    """
    val &= 0xffff
    return val - 0x10000 if val & 0x8000 else val

def unsigned(val):
    """Return VAL as an unsigned 16-bit word (two's complement for negatives).

    Any int is accepted and reduced modulo 2**16 instead of raising
    struct.error for values outside the signed 16-bit range.
    """
    return val & 0xffff

def parse_number(v):
    """Parse a debugger number argument into a 16-bit word.

    Accepts an int, a decimal string with optional sign ('12', '-1'), or a
    hexadecimal string prefixed with '#' ('#ff').  The result is masked to
    16 bits.  Returns None when V is not a valid number.
    """
    if isinstance(v, int):
        return v & 0xffff
    if isinstance(v, str):
        if re.fullmatch(r'[-+]?\d+', v):
            return int(v) & 0xffff
        # only real hex digits: '#zz' must be rejected, not crash int()
        m = re.fullmatch(r'#([0-9a-fA-F]+)', v)
        if m:
            return int(m.group(1), 16) & 0xffff
    return None

# ----------------------------------------------------------------
class Memory:
    """Word-addressed main memory covering the 16-bit address space.

    Addresses wrap around modulo 0x10000 and stored values are truncated to
    one 16-bit word, matching the 16-bit arithmetic used by the emulator.
    """

    SIZE = 0x10000  # number of addressable words

    def __init__(self):
        # a fresh list is already all zeros; no separate clear() pass needed
        self._memory = [0] * self.SIZE

    def __setitem__(self, addr, val):
        self._memory[addr & 0xffff] = val & 0xffff

    def __getitem__(self, addr):
        return self._memory[addr & 0xffff]

    def clear(self):
        """Reset every word to zero."""
        self._memory[:] = [0] * self.SIZE

class State:
    """CPU register set: program counter, flag register, and GR0-GR4.

    GR4 is used as the stack pointer and therefore starts at STACK_TOP.
    """

    def __init__(self):
        self.gr = [0] * 4 + [STACK_TOP]
        self.fr = FR_ZERO
        self.pc = 0

class Comet:
    """COMET virtual machine: main memory, CPU registers and breakpoints.

    Address arithmetic wraps around the 16-bit address space, and register,
    stack-pointer and PC updates are truncated to 16 bits.
    """

    def __init__(self):
        self.memory = Memory()
        self.state = State()
        self.breakpoints = []  # addresses at which cmd_run() stops

    def __repr__(self):
        return f'Comet({self.state.pc:04x}, {self.state.gr[0]:02x}, {self.state.gr[1]:02x}, {self.state.gr[2]:02x}, {self.state.gr[3]:02x}, {self.state.gr[4]:02x} {self.state.fr:#b})'

    def decode(self, addr=None):
        """Split the instruction at ADDR (default: the PC) into its fields.

        Returns (word, inst, gr, adr, xr): the first word, its opcode (high
        byte), its GR field (bits 4-7), the word following it (adr), and its
        XR field (bits 0-3).
        """
        if addr is None:
            addr = self.state.pc
        addr &= 0xffff
        word = self.memory[addr]
        inst = word >> 8
        gr = (word >> 4) & 0xf
        xr = word & 0xf
        # the second word wraps to address 0 instead of indexing past the end
        adr = self.memory[(addr + 1) & 0xffff]
        return word, inst, gr, adr, xr

    def parse(self, addr=None):
        """Disassemble the instruction at ADDR (default: the PC).

        Returns (nemonic, opr, size): the mnemonic, the operand text, and
        the instruction length in words.  A word whose opcode is not in
        INSTTBL is shown as a 'DC' constant; the addresses SYS_IN, SYS_OUT
        and SYS_EXIT are always reported as the IN/OUT/EXIT system calls.
        """
        debug(f'parse({self}, {addr})')
        if addr is None:
            addr = self.state.pc
        addr &= 0xffff

        # decode the instruction at ADDR
        word, inst, gr, adr, xr = self.decode(addr)

        if inst in INSTTBL:
            nemonic = INSTTBL[inst].nemonic
            type_ = INSTTBL[inst].type
            if type_ == 'op1':
                # GR, adr[, XR]
                opr = f'GR{gr}, #{adr:04x}'
                if xr > 0:
                    opr += f', GR{xr}'
                size = 2
            elif type_ == 'op2':
                # adr[, XR]
                opr = f'#{adr:04x}'
                if xr > 0:
                    opr += f', GR{xr}'
                size = 2
            elif type_ == 'op3':
                # GR only
                opr = f'GR{gr}'
                size = 2
            else:
                # 'op4': no operand
                opr = ''
                size = 1
        else:
            # interpret as a binary word by default
            nemonic = 'DC'
            opr = f'#{word:04x}'
            size = 1

        # for IN/OUT/EXIT system calls
        if addr == SYS_IN:
            nemonic, opr = 'IN', 'SYSTEM CALL'
            size = 2
        elif addr == SYS_OUT:
            nemonic, opr = 'OUT', 'SYSTEM CALL'
            size = 2
        elif addr == SYS_EXIT:
            nemonic, opr = 'EXIT', 'SYSTEM CALL'
            size = 2

        return nemonic, opr, size

    def update_fr(self, val):
        """Set FR from VAL: FR_MINUS if bit 15 is set, FR_ZERO if VAL is 0,
        FR_PLUS otherwise."""
        if val & 0x8000:
            self.state.fr = FR_MINUS
        elif val == 0:
            self.state.fr = FR_ZERO
        else:
            self.state.fr = FR_PLUS

    def step_exec(self):
        """Execute the single instruction at the PC.

        Updates registers, memory and FR, then moves the PC to the next
        instruction or the jump target.  The IN/OUT system-call addresses
        are handled by exec_in/exec_out; EXIT terminates the process with
        status 1; an unknown opcode is reported via die().
        """
        debug(f'step_exec({self})')
        pc = self.state.pc  # to be updated
        fr = self.state.fr  # read-only
        regs = self.state.gr  # alias

        # calculate the effective address: adr plus the index register
        word, inst, gr, adr, xr = self.decode(pc)
        eadr = adr
        if 1 <= xr <= 4:
            eadr += regs[xr]
        eadr &= 0xffff

        # obtain the mnemonic for the current address
        nemonic, opr, size = self.parse()
        if nemonic == 'IN':
            self.exec_in()
            return
        elif nemonic == 'OUT':
            self.exec_out()
            return
        elif nemonic == 'EXIT':
            sys.exit(1)
        elif nemonic == 'LD':
            # LD and ST do not modify FR
            regs[gr] = self.memory[eadr]
            pc += 2
        elif nemonic == 'ST':
            self.memory[eadr] = regs[gr]
            pc += 2
        elif nemonic == 'LEA':
            regs[gr] = eadr
            self.update_fr(regs[gr])
            pc += 2
        elif nemonic == 'ADD':
            regs[gr] = (regs[gr] + self.memory[eadr]) & 0xffff
            self.update_fr(regs[gr])
            pc += 2
        elif nemonic == 'SUB':
            regs[gr] = (regs[gr] - self.memory[eadr]) & 0xffff
            self.update_fr(regs[gr])
            pc += 2
        elif nemonic == 'AND':
            regs[gr] &= self.memory[eadr]
            self.update_fr(regs[gr])
            pc += 2
        elif nemonic == 'OR':
            regs[gr] |= self.memory[eadr]
            self.update_fr(regs[gr])
            pc += 2
        elif nemonic == 'EOR':
            regs[gr] ^= self.memory[eadr]
            self.update_fr(regs[gr])
            pc += 2
        elif nemonic == 'CPA':
            # signed comparison: only the sign/zero of the clamped
            # difference matters for FR
            v = signed(regs[gr]) - signed(self.memory[eadr])
            v = max(MIN_SIGNED, min(MAX_SIGNED, v))
            self.update_fr(unsigned(v))
            pc += 2
        elif nemonic == 'CPL':
            # unsigned comparison, same clamping trick as CPA
            v = regs[gr] - self.memory[eadr]
            v = max(MIN_SIGNED, min(MAX_SIGNED, v))
            self.update_fr(unsigned(v))
            pc += 2
        elif nemonic == 'SLA':
            # shift the low 15 bits left and keep the original sign bit;
            # bits shifted past bit 14 are discarded
            sign = regs[gr] & 0x8000
            regs[gr] = sign | ((regs[gr] << eadr) & 0x7fff)
            self.update_fr(regs[gr])
            pc += 2
        elif nemonic == 'SRA':
            # arithmetic shift right: fill vacated high bits with the sign
            v = regs[gr]
            if v & 0x8000:
                v &= 0x7fff
                v >>= eadr
                v += ((0x7fff >> eadr) ^ 0xffff)
            else:
                v >>= eadr
            regs[gr] = v
            self.update_fr(regs[gr])
            pc += 2
        elif nemonic == 'SLL':
            # logical shift left, truncated to 16 bits
            regs[gr] = (regs[gr] << eadr) & 0xffff
            self.update_fr(regs[gr])
            pc += 2
        elif nemonic == 'SRL':
            regs[gr] >>= eadr
            self.update_fr(regs[gr])
            pc += 2
        elif nemonic == 'JPZ':
            if fr != FR_MINUS:
                pc = eadr
            else:
                pc += 2
        elif nemonic == 'JMI':
            if fr == FR_MINUS:
                pc = eadr
            else:
                pc += 2
        elif nemonic == 'JNZ':
            if fr != FR_ZERO:
                pc = eadr
            else:
                pc += 2
        elif nemonic == 'JZE':
            if fr == FR_ZERO:
                pc = eadr
            else:
                pc += 2
        elif nemonic == 'JMP':
            pc = eadr
        elif nemonic == 'PUSH':
            # GR4 is the stack pointer; the stack grows downward
            regs[4] = (regs[4] - 1) & 0xffff
            self.memory[regs[4]] = eadr
            pc += 2
        elif nemonic == 'POP':
            regs[gr] = self.memory[regs[4]]
            regs[4] = (regs[4] + 1) & 0xffff
            pc += 2
        elif nemonic == 'CALL':
            regs[4] = (regs[4] - 1) & 0xffff
            self.memory[regs[4]] = (pc + 2) & 0xffff  # return address
            pc = eadr
        elif nemonic == 'RET':
            pc = self.memory[regs[4]]
            regs[4] = (regs[4] + 1) & 0xffff
        else:
            die(f'illegal instruction {inst:02x} at {pc:04x}')

        # the PC wraps around the 16-bit address space
        self.state.pc = pc & 0xffff

    def exec_in(self):
        """Handle the IN system call.

        On entry the stack holds the return address followed by two
        arguments: the address of the length word and the address of the
        buffer.  Reads a line from stdin (at most 80 characters kept),
        stores one character per word into the buffer and the character
        count into the length word, then pops the return address and
        resumes there.  Characters outside ASCII are stored as '?'.
        """
        debug(f'exec_in({self})')
        sp = self.state.gr[4]
        pc = self.memory[sp]
        len_addr = self.memory[(sp + 1) & 0xffff]
        buf_addr = self.memory[(sp + 2) & 0xffff]
        line = input('IN > ')  # prompt for input
        # keep at most 80 characters; 'replace' turns each non-ASCII
        # character into one '?' instead of raising UnicodeEncodeError
        data = line[:80].encode('ascii', errors='replace')
        self.memory[len_addr] = len(data)
        for offset, c in enumerate(data):
            self.memory[(buf_addr + offset) & 0xffff] = c
        self.state.pc = pc  # go back to the caller
        self.state.gr[4] = (sp + 1) & 0xffff

    def exec_out(self):
        """Handle the OUT system call.

        On entry the stack holds the return address followed by two
        arguments: the address of the length word and the address of the
        buffer.  Writes the low byte of each buffer word to stdout after an
        'OUT> ' prefix, then pops the return address and resumes there.
        """
        debug(f'exec_out({self})')
        sp = self.state.gr[4]
        pc = self.memory[sp]
        len_addr = self.memory[(sp + 1) & 0xffff]
        buf_addr = self.memory[(sp + 2) & 0xffff]
        size = self.memory[len_addr]
        text = ''.join(chr(self.memory[(buf_addr + n) & 0xffff] & 0xff)
                       for n in range(size))
        print('OUT> ' + text)
        self.state.pc = pc  # go back to the caller
        self.state.gr[4] = (sp + 1) & 0xffff

# ----------------------------------------------------------------
def cmd_run(comet, *args):
    debug(f'cmd_run({comet}, {args})')
    # Step at least once (so 'run' can resume from a breakpoint), then keep
    # going until the PC reaches a breakpoint; EXIT ends the process itself.
    while True:
        comet.step_exec()
        pc = comet.state.pc
        if pc in comet.breakpoints:
            index = comet.breakpoints.index(pc)
            print(f'Breakpoint {index}, #{pc:04x}')
            return

def cmd_step(comet, *args):
    """Execute N instructions; N defaults to 1 (decimal or #hex).

    An unparsable N is reported with warn() instead of crashing on
    range(None).
    """
    debug(f'cmd_step({comet}, {args})')
    if args:
        count = parse_number(args[0])
        if count is None:
            warn(f'invalid step count "{args[0]}"')
            return
    else:
        count = 1
    for _ in range(count):
        comet.step_exec()

def cmd_break(comet, *args):
    debug(f'cmd_break({comet}, {args})')
    # without an address argument there is nothing to set
    if not args:
        return
    target = parse_number(args[0])
    if target is None:
        warn(f'invalid breakpoint address "{args[0]}"')
        return
    comet.breakpoints.append(target)

def cmd_delete(comet, *args):
    """Delete breakpoints.

    With an argument N, delete breakpoint N.  Numbering starts at 0, the
    same numbering cmd_info and cmd_run display.  Without an argument, ask
    for confirmation and then delete all breakpoints.  An invalid or
    out-of-range N is reported with warn().
    """
    debug(f'cmd_delete({comet}, {args})')
    if not args or not args[0]:
        resp = input('Delete all breakpoints? (y or n) ')
        if re.search(r'^[yY]', resp):
            comet.breakpoints.clear()
        return
    n = parse_number(args[0])
    if n is None or n >= len(comet.breakpoints):
        warn(f'no breakpoint number "{args[0]}"')
        return
    del comet.breakpoints[n]

def cmd_info(comet, *args):
    debug(f'cmd_info({comet}, {args})')
    # one line per breakpoint: "<index>: #<address>"
    listing = [f'{index}: #{target:04x}'
               for index, target in enumerate(comet.breakpoints)]
    for entry in listing:
        print(entry)

def cmd_print(comet, *args):
    debug(f'cmd_print({comet}, {args})')
    # disassemble whatever the PC currently points at
    mnemonic, operand, _size = comet.parse()
    regs, flags = comet.state.gr, comet.state.fr
    cells = [f"GR{i} #{regs[i]:04x} ({signed(regs[i]):6})" for i in range(5)]
    report = (
        f"PC  #{comet.state.pc:04x} [ {mnemonic} {operand} ]\n"
        + ' '.join(cells[:3]) + "\n"
        + ' '.join(cells[3:]) + f" FR  {flags:#b} ({flags:6})\n"
    )
    print(report)

def cmd_dump(comet, *args):
    """Dump 128 words (16 rows of 8) of memory from an address (default: PC).

    Each row shows the row address, eight words in hex, and the low byte of
    each word as a printable ASCII character ('.' otherwise).  Addresses
    wrap around at #ffff.  An invalid address is reported with warn().
    """
    debug(f'cmd_dump({comet}, {args})')
    if args:
        addr = parse_number(args[0])
        if addr is None:
            warn(f'invalid dump address "{args[0]}"')
            return
    else:
        addr = comet.state.pc
    for row in range(16):
        base = (addr + (row << 3)) & 0xffff
        words = [comet.memory[(base + col) & 0xffff] for col in range(8)]
        hexes = ''.join(f' {v:04x}' for v in words)
        # 0x7f (DEL) is a control character, not printable
        chars = ''.join(chr(v & 0xff) if 0x20 <= (v & 0xff) <= 0x7e else '.'
                        for v in words)
        print(f'{base:04x}{hexes} {chars}')

def cmd_stack(comet, *args):
    debug(f'cmd_stack({comet}, {args})')
    # GR4 holds the stack pointer: dump memory starting there
    cmd_dump(comet, comet.state.gr[4])

def cmd_file(comet, file=None, *args):
    """Load a COMET object file into memory and set the PC to its entry.

    File layout: a 16-byte header that starts with b'CASL' and has the
    start address in bytes 4-5 (big-endian), followed by big-endian 16-bit
    words loaded from address 0 upward (below STACK_TOP).  A missing file
    name or an unreadable file is reported with warn(); a malformed or
    oversized object is reported with die().
    """
    debug(f'cmd_file({comet}, {file})')
    if not file:
        warn('file command needs a file name')
        return
    try:
        f = open(file, 'rb')
    except OSError as e:
        warn(f'{file}: {e}')
        return
    with f:
        print(f'Reading object from {file}...', end='')

        # parse the file header
        header = f.read(16)
        if len(header) < 6 or header[:4] != b'CASL':
            die(f'{file}: not a COMET object file')
        comet.state.pc = struct.unpack('>H', header[4:6])[0]

        # load object into the memory
        addr = 0
        while True:
            buf = f.read(2)
            if not buf:
                break
            if len(buf) < 2:
                die(f'{file}: truncated object file')
            if addr >= STACK_TOP:
                die('out of memory')
            comet.memory[addr] = struct.unpack('>H', buf)[0]
            addr += 1
        print('done.')

def cmd_jump(comet, *args):
    debug(f'cmd_jump({comet}, {args})')
    # without an address argument the PC is left alone
    if not args:
        return
    target = parse_number(args[0])
    if target is None:
        warn(f'invalid jump address "{args[0]}"')
    else:
        # parse_number already returns a 16-bit value
        comet.state.pc = target

def cmd_memory(comet, *args):
    """Store VALUE at ADDRESS: 'memory ADDRESS VALUE' (decimal or #hex).

    Missing or unparsable arguments are reported with warn().
    """
    debug(f'cmd_memory({comet}, {args})')
    if len(args) < 2:
        warn('memory command needs address and value')
        return
    addr = parse_number(args[0])
    val = parse_number(args[1])
    if addr is None or val is None:
        warn(f'invalid address "{args[0]}" or value "{args[1]}"')
        return
    comet.memory[addr] = val

def cmd_disasm(comet, *args):
    """Disassemble 16 instructions from an address (default: the PC).

    Addresses wrap around at #ffff.  An invalid address is reported with
    warn().
    """
    debug(f'cmd_disasm({comet}, {args})')
    if args:
        addr = parse_number(args[0])
        if addr is None:
            warn(f'invalid disassemble address "{args[0]}"')
            return
    else:
        addr = comet.state.pc
    for _ in range(16):
        inst, opr, size = comet.parse(addr)
        print(f'#{addr:04x}\t{inst}\t{opr}')
        addr = (addr + size) & 0xffff

def cmd_help(comet, *args):
    """Print the list of debugger commands with their short forms."""
    debug(f'cmd_help({comet}, {args})')
    print("""\
List of commands:

r,  run         Start execution of program.
s,  step        Step execution.  Argument N means do this N times.
b,  break       Set a breakpoint at specified address.
de, del         Delete some breakpoints.
i,  info        Print information on breakpoints.
p,  print       Print status of PC/FR/GR0/GR1/GR2/GR3/GR4 registers.
du, dump        Dump 128 words of memory image from specified address.
st, stack       Dump 128 words of stack image.
f,  file        Use FILE as program to be debugged.
j,  jump        Set the PC to specified address.
m,  memory      Change the memory at ADDRESS to VALUE.
di, disasm      Disassemble 16 instructions from specified address.
h,  help        Print list of commands.
q,  quit        Exit comet.""")

def cmd_quit(comet, *args):
    """Leave the debugger, terminating the process with exit status 1."""
    raise SystemExit(1)

# Debugger commands.  main() tries these in order with
# re.search('^' + regexp, cmd), so each regexp is a prefix match on the
# command word.  Alternatives are grouped so that the '^' anchors all of
# them, not just the first.
CMDTBL = [
    Command(r'(?:de|del)', cmd_delete, False),
    Command(r'(?:du|dump)', cmd_dump, False),
    Command(r'(?:b|break)', cmd_break, False),
    Command(r'(?:di|disasm)', cmd_disasm, False),
    Command(r'(?:f|file)', cmd_file, True),
    Command(r'(?:h|\?|help)', cmd_help, False),
    Command(r'(?:i|info)', cmd_info, False),
    Command(r'(?:j|jump)', cmd_jump, True),
    Command(r'(?:m|memory)', cmd_memory, True),
    Command(r'(?:p|print)', cmd_print, False),
    Command(r'(?:q|quit)', cmd_quit, False),
    Command(r'(?:r|run)', cmd_run, True),
    # 'step' must be tried before 'stack': the prefix 'st' would otherwise
    # capture "step".  's' not followed by 't' keeps all other s-commands.
    Command(r'(?:step|s(?!t))', cmd_step, True),
    Command(r'(?:st|stack)', cmd_stack, False),
]

# ----------------------------------------------------------------
def parse_options(opt):
    """Apply command-line options: -q hides the banner, -d enables debug()."""
    global DEBUG
    if not opt.q:
        banner = '\n'.join([
            f'This is COMET, version {VERSION}.',
            'Copyright (c) 2021, Hiroyuki Ohsaki.',
            'All rights reserved.',
        ])
        print(banner)
    if opt.d:
        DEBUG = True

def main():
    """Entry point: parse options, load an optional object file, run the REPL.

    Reads commands from stdin until EOF.  An empty line repeats the previous
    command.  Ctrl-C while a command runs (e.g. a 'run' that never reaches a
    breakpoint) returns to the prompt instead of killing the debugger.
    """
    opt = getopts('qd') or usage()
    parse_options(opt)

    comet = Comet()
    if len(sys.argv) >= 2:
        file = sys.argv[1]
        cmd_file(comet, file)

    last_line = ''
    cmd_print(comet)
    while True:
        # show prompt and input command from STDIN
        try:
            line = input('comet> ')
        except EOFError:
            # end of input (e.g. Ctrl-D): leave the debugger cleanly
            print()
            break
        if line == '':
            line = last_line
        last_line = line
        # str.split() also ignores leading whitespace
        words = line.split()
        if not words:
            continue
        cmd, *args = words

        # execute command according to command
        for pattern, subr, need_print in CMDTBL:
            if re.search(r'^' + pattern, cmd):
                try:
                    subr(comet, *args)
                except KeyboardInterrupt:
                    print('\nInterrupted.')
                if need_print:
                    cmd_print(comet)
                break
        else:
            print(f'undefined command: "{cmd}". Try "help"')

if __name__ == '__main__':
    # start the interactive debugger only when run as a script
    main()
