#!python
# -*- coding: utf-8 -*-
# Author: Stefan Agner <stefan@agner.ch>

import asyncio
import click
import functools
import os

import bellows.config as config
import bellows.ezsp
import bellows.types

import logging
import serial
import time

import re
import pexpect
from pexpect import fdpexpect

from xmodem import XMODEM

# Module-level logger, named after this module.
LOGGER = logging.getLogger(__name__)

def background(f):
    """Turn the coroutine function ``f`` into a plain synchronous callable.

    The returned wrapper builds the coroutine from the given arguments and
    drives it to completion with ``asyncio.run``, returning its result.
    Metadata of ``f`` (name, docstring, ...) is copied onto the wrapper via
    ``functools.wraps``.
    """

    def run_to_completion(*args, **kwargs):
        coroutine = f(*args, **kwargs)
        return asyncio.run(coroutine)

    return functools.wraps(f)(run_to_completion)

async def setup(dev, baudrate, cbh=None, configure=True):
    """Probe, connect to and reset an EZSP device.

    Args:
        dev: Device path stored under ``config.CONF_DEVICE_PATH``.
        baudrate: Baud rate stored under ``config.CONF_DEVICE_BAUDRATE``.
        cbh: Optional callback; when truthy it is registered on the EZSP
            instance with ``add_callback`` before connecting.
        configure: Accepted for interface compatibility; not used here.

    Returns:
        The connected ``bellows.ezsp.EZSP`` instance, or ``None`` when
        ``EZSP.probe`` reports failure for the device configuration.

    Raises:
        click.Abort: If ``connect()`` raises; the error is logged first.
    """
    settings = {
        config.CONF_DEVICE_PATH: dev,
        config.CONF_DEVICE_BAUDRATE: baudrate,
        config.CONF_FLOW_CONTROL: config.CONF_FLOW_CONTROL_DEFAULT,
    }

    reachable = await bellows.ezsp.EZSP.probe(settings)
    if not reachable:
        return None

    ezsp = bellows.ezsp.EZSP(settings)
    if cbh:
        ezsp.add_callback(cbh)

    try:
        await ezsp.connect()
    except Exception as err:
        LOGGER.error(err)
        raise click.Abort()
    else:
        # Only connect() failures are turned into click.Abort; errors from
        # reset()/version() propagate unchanged.
        LOGGER.debug("Connected. Resetting.")
        await ezsp.reset()
        await ezsp.version()
        return ezsp

async def launch_bootloader(ezsp):
    """Ask the device to start its standalone bootloader via EZSP.

    Args:
        ezsp: Connected EZSP object providing ``launchStandaloneBootloader``.

    Returns:
        ``False`` when the command answers with a status other than
        ``EmberStatus.SUCCESS`` (the status is echoed); ``True`` otherwise,
        including when the command times out.
    """
    # NOTE(review): the message says "recovery mode"; confirm that mode
    # argument 0x01 really selects that mode.
    click.echo("Launching bootloader in mode recovery mode via EZSP.")

    try:
        response = await ezsp.launchStandaloneBootloader(0x01)
    except asyncio.exceptions.TimeoutError:
        # Timeouts are expected here as we enter the boot loader.
        return True

    status = response[0]
    if status != bellows.types.EmberStatus.SUCCESS:
        click.echo(f"Couldn't launch bootloader: {status}")
        return False

    return True

def upload_firmware(port, baudrate, firmware):
    """Upload a firmware image to the bootloader menu over XMODEM.

    Opens ``port`` with pyserial, waits for the "Gecko Bootloader" banner,
    selects menu entry ``1`` ("upload gbl"), sends ``firmware`` with XMODEM
    and, when the bootloader reports a successful upload, selects menu entry
    ``2`` to start the flashed firmware. Progress and results are reported
    with ``click.echo``.

    Args:
        port: Serial device passed to ``serial.Serial``.
        baudrate: Baud rate passed to ``serial.Serial``.
        firmware: Open binary file object. It must support ``fileno()``
            (its size is read with ``os.fstat`` for the progress bar) and is
            read by ``XMODEM.send``.

    Returns:
        None. The function returns early if the bootloader banner does not
        appear within the pexpect timeout.

    Raises:
        pexpect exceptions propagate when any later expected prompt does not
        arrive within the 2 second timeout.
    """
    # Communicate with Bootloader using pyserial.
    sp = serial.Serial(
        port=port,
        baudrate=baudrate,
        stopbits=serial.STOPBITS_ONE,
        bytesize=serial.EIGHTBITS,
    )
    try:
        # pexpect drives the bootloader's text menu on the same descriptor
        # that pyserial opened.
        m = fdpexpect.fdspawn(sp.fileno(), timeout=2)

        # Raw bytes literal: "\." is an invalid escape in a normal literal.
        blbanner = re.compile(rb"Gecko Bootloader v([0-9\.]*)\r\n")
        # An empty line makes the bootloader print its banner and menu.
        m.sendline("")
        idx = m.expect([blbanner, pexpect.TIMEOUT])
        if idx == 1:
            click.echo("Failed to launch the bootloader. Bootloader banner did not appear.")
            return

        blversion = m.match.group(1).decode()
        click.echo("Bootloader detected successfully.")
        click.echo(f"Bootloader version: {blversion}")

        m.expect("1. upload gbl")
        m.expect("BL > ")
        m.send("1")
        # Wait for the "C" the bootloader prints once it is ready to receive.
        m.expect("C")

        click.echo("Starting firmware upload...")
        with click.progressbar(length=os.fstat(firmware.fileno()).st_size,
                               label='Firmware update') as progress:
            # XMODEM I/O callbacks; the timeout argument is not applied to
            # the serial port.
            def getc(size, timeout=1):
                return sp.read(size)

            def putc(data, timeout=1):
                written = sp.write(data)
                # NOTE(review): data presumably includes XMODEM framing, so
                # the bar may fill slightly before the transfer ends.
                progress.update(len(data))
                # Report the byte count, as the xmodem putc contract documents.
                return written

            modem = XMODEM(getc, putc)
            sentcheck = modem.send(firmware, quiet=True, retry=0)

        if not sentcheck:
            # Keep going: the bootloader's own status is read and shown below.
            click.echo("XMODEM transfer failed.")

        idx = m.expect(["Serial upload complete", "Serial upload aborted"])
        if idx == 0:
            click.echo("Bootloader reported successful upload.")
            click.echo("Starting flashed firmware...")
            m.expect("BL > ")
            m.send("2")
        elif idx == 1:
            click.echo("Bootloader reported aborted upload:")
            # Everything printed before the next banner is the error text.
            errortobanner = re.compile(rb"(.*)\r\nGecko Bootloader", re.DOTALL)
            m.expect(errortobanner)
            error = m.match.group(1).decode().strip()
            click.echo(error)
    finally:
        # Always release the serial port, including on early return or error.
        sp.close()


# Root command group. The docstring below is shown by click as the help text,
# so it is kept as the short user-facing description.
@click.group()
@click.option("--device", envvar="SILABS_FIRMWARE_DEVICE", required=True)
@click.option(
    "--baudrate",
    envvar="SILABS_FIRMWARE_BAUDRATE",
    default=115200,
    show_default=True,
)
@click.pass_context
def main(ctx, device, baudrate):
    """Silicon Labs firmware upgrade utility"""
    # Share the connection settings with the subcommands through the context.
    ctx.obj = dict(device=device, baudrate=baudrate)

# Subcommand "info": connect over EZSP and print the board manufacturer, board
# name, EmberZNet version and standalone bootloader details.
# Documented with comments rather than a docstring because click would show a
# docstring as this command's help text.
@main.command()
@click.pass_context
@background
async def info(ctx):
    ezsp = await setup(ctx.obj["device"], ctx.obj["baudrate"])
    if ezsp is None:
        # setup() returns None when the EZSP probe fails; report that cleanly
        # instead of failing with AttributeError on the first query.
        raise click.ClickException("Failed communicating using EZSP.")

    try:
        brd_manuf, brd_name, stack_version = await ezsp.get_board_info()
        click.echo(f"Manufacturer: {brd_manuf}")
        click.echo(f"Board name: {brd_name}")
        click.echo(f"EmberZNet version: {stack_version}")

        (
            bl_version,
            plat,
            micro,
            phy,
        ) = await ezsp.getStandaloneBootloaderVersionPlatMicroPhy()
        # A version of 0xFFFF is reported as "no boot loader installed".
        if bl_version == 0xFFFF:
            click.echo("No boot loader installed")
            return

        click.echo(
            (
                f"bootloader version: 0x{bl_version:04x}, nodePlat: 0x{plat:02x}, "
                f"nodeMicro: 0x{micro:02x}, nodePhy: 0x{phy:02x}"
            )
        )
    finally:
        # Close the connection on every exit path, including errors.
        ezsp.close()

# Subcommand "flash": optionally switch the device into its bootloader (via
# EZSP and/or GPIO lines 24/25 on gpiochip 0), then upload the firmware file.
# Documented with comments rather than a docstring because click would show a
# docstring as this command's help text.
@main.command()
@click.pass_context
@click.option(
    "--firmware",
    type=click.File("rb"),
    required=True,
    envvar="SILABS_FIRMWARE_FILE",
)
@click.option("--ezsp-reset/--no-ezsp-reset", default=True, show_default=True)
@click.option("--cm4-gpio-reset", is_flag=True, default=False, show_default=True)
@background
async def flash(ctx, firmware, ezsp_reset, cm4_gpio_reset):
    port, speed = ctx.obj["device"], ctx.obj["baudrate"]

    if ezsp_reset:
        click.echo("Trying to connect using EZSP...")
        ezsp = await setup(port, speed)
        if ezsp is not None:
            # The result of launch_bootloader() is not checked; the upload
            # step below detects a missing bootloader banner itself.
            await launch_bootloader(ezsp)
            ezsp.close()
        else:
            click.echo("Failed communicating using EZSP, assuming we are in bootloader.")
        # Rebind to None so no reference to the EZSP object is kept.
        ezsp = None

    if cm4_gpio_reset:
        # Imported lazily: gpiod is only needed for this option.
        import gpiod

        gpio_chip = gpiod.Chip('0', gpiod.Chip.OPEN_BY_NUMBER)
        boot_line = gpio_chip.get_line(24)
        reset_line = gpio_chip.get_line(25)
        for gpio_line in (boot_line, reset_line):
            gpio_line.request("silabs-flasher")

        reset_line.set_direction_output(0)  # Assert Reset
        boot_line.set_direction_output(0)  # 0=BL mode, 1=Firmware
        time.sleep(0.1)
        reset_line.set_direction_output(1)  # Deassert Reset
        time.sleep(0.1)
        # This clears all GPIO leaving them as input/without pulls
        # External pulls will make sure the SoC will enter Firmware
        # on next reset.
        reset_line.set_direction_input()
        boot_line.set_direction_input()
        gpio_chip.close()

    upload_firmware(port, speed, firmware)

# Run the click command group when this file is executed as a script.
if __name__ == '__main__':
    main()

