#!/usr/bin/python3
#
# Copyright 2021 Josh Pieper, jjp@pobox.com.
# Copyright 2023 Inria
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Use the raw CAN API to query the power dist board.

See https://github.com/mjbots/power_dist/blob/main/docs/reference.md
"""

import asyncio
import moteus
import moteus_pi3hat
import numpy as np
import os
import struct
import sys


# Numeric codes identifying how a register value is encoded on the wire.
# Stream.read_type uses them to pick the matching decoder.
TYPECODES = {
    "int8": 0,
    "int16": 1,
    "int32": 2,
    "f32": 3,
}


class Stream:
    """Consuming reader that decodes fields from a bytes buffer.

    Every ``read_*`` method removes the bytes it decodes from the front of
    ``self.data`` and returns a ``(raw_bytes, decoded_value)`` pair.

    Source: ``utils/decode_can_frame.py`` in
    [mjbots/moteus](https://github.com/mjbots/moteus/)."""

    def __init__(self, data):
        # Bytes not yet consumed; shrinks from the front on every read.
        self.data = data

    def remaining(self):
        """Return the number of bytes that have not been consumed yet."""
        return len(self.data)

    def _read_value(self, size):
        """Pop and return up to ``size`` bytes from the front of the buffer."""
        head = self.data[0:size]
        self.data = self.data[size:]
        return head

    def _read_byte(self):
        """Pop one byte and return it as an int.

        Raises:
            IndexError: If the buffer is empty.
        """
        return self._read_value(1)[0]

    def read_struct(self, fmt):
        """Pop one value described by the ``struct`` format string ``fmt``.

        Returns:
            Tuple ``(raw, value)`` where ``raw`` is the consumed bytes and
            ``value`` is the first field unpacked from them.

        Raises:
            struct.error: If fewer bytes remain than the format requires.
        """
        layout = struct.Struct(fmt)
        raw = self._read_value(layout.size)
        fields = layout.unpack(raw)
        return raw, fields[0]

    def read_int8(self):
        """Pop a signed 8-bit integer."""
        return self.read_struct("<b")

    def read_int16(self):
        """Pop a little-endian signed 16-bit integer."""
        return self.read_struct("<h")

    def read_int32(self):
        """Pop a little-endian signed 32-bit integer."""
        return self.read_struct("<i")

    def read_f32(self):
        """Pop a little-endian 32-bit float."""
        return self.read_struct("<f")

    def read_varuint(self):
        """Pop a variable-length unsigned integer of at most five bytes.

        Each byte contributes its low 7 bits, least-significant group
        first; a clear high bit marks the last byte.

        Returns:
            Tuple ``(raw, value)`` of the consumed bytes and decoded integer.

        Raises:
            RuntimeError: If the fifth byte still has its high bit set.
            IndexError: If the buffer runs out before the last byte.
        """
        raw = bytearray()
        value = 0
        for shift in range(0, 35, 7):  # five 7-bit groups at most
            byte = self._read_byte()
            raw.append(byte)
            value |= (byte & 0x7F) << shift
            if not byte & 0x80:
                return bytes(raw), value

        raise RuntimeError(f"Invalid varuint {raw.hex()}")

    def read_type(self, typecode):
        """Pop one value whose encoding is selected by ``typecode``.

        Raises:
            RuntimeError: If ``typecode`` is not one of ``TYPECODES``.
        """
        readers = (
            ("int8", self.read_int8),
            ("int16", self.read_int16),
            ("int32", self.read_int32),
            ("f32", self.read_f32),
        )
        for type_name, reader in readers:
            if typecode == TYPECODES[type_name]:
                return reader()
        raise RuntimeError(f"Unknown type: {typecode}")


def _make_power_query():
    """Build the raw CAN command that reads output voltage and current."""
    raw_message = moteus.Command()
    raw_message.raw = True
    raw_message.destination = 32
    raw_message.bus = 5
    raw_message.reply_required = True

    # 0x80: ask for reply, and we (the pi3hat) have ID zero
    # 0x20: ID of the power dist board (0x20 == 32)
    # https://github.com/mjbots/moteus/blob/main/docs/reference.md#can-id
    raw_message.arbitration_id = 0x8020

    # 0x1e: read two floats
    # https://github.com/mjbots/moteus/blob/main/docs/reference.md#a1b-read-registers
    # 0x10: starting from Output Voltage register
    # https://github.com/mjbots/power_dist/blob/main/docs/reference.md#0x010---output-voltage
    raw_message.data = bytes.fromhex("1e10")
    return raw_message


def _decode_power_samples(payload):
    """Decode the (voltage, current) pairs carried by one reply payload.

    Problems are reported on standard output rather than raised, so that a
    single bad frame does not abort a whole sampling run.

    Args:
        payload: Raw data bytes of one frame received on the CAN bus.

    Returns:
        List of ``(voltage, current)`` tuples, one for each well-formed
        "reply, two floats, starting at register 0x10" subframe found.
    """
    samples = []
    stream = Stream(payload)
    try:
        while stream.remaining():
            data, cmd = stream.read_int8()
            if cmd == 0x50:  # NOP
                continue

            is_expected_reply = data.hex() == "2e"  # 0x2e: reply, two floats
            if not is_expected_reply:
                print(f"Unexpected reply: {data.hex()=}")

            # Once a subframe is not the one we asked for, its length is
            # unknown and the following bytes cannot be told apart from
            # command bytes, so the rest of the frame is abandoned (break)
            # instead of being re-read as commands.
            typecode = (cmd & 0b00001100) >> 2
            if typecode != TYPECODES["f32"]:
                print(f"Unexpected type: {typecode=} not f32")
                break

            length = cmd & 0b00000011
            if length != 2:
                print(f"Unexpected {length} != 2 registers")
                break

            if not is_expected_reply:
                # Right type and count but not a reply (e.g. 0x1e or 0x3e):
                # its payload must not be decoded as a measurement.
                break

            start_reg_data, _ = stream.read_varuint()
            if start_reg_data.hex() != "10":
                print("Not starting at the Output Voltage register 0x10")
                break

            _, voltage = stream.read_f32()
            _, current = stream.read_f32()
            samples.append((voltage, current))
    except (IndexError, struct.error, RuntimeError) as exc:
        # Stream raises these when the payload ends in the middle of a
        # field or holds an invalid varuint; drop the incomplete sample.
        print(f"Truncated or malformed reply: {exc}")
    return samples


async def main(sampling_duration: float = 0.5, wait_dt: float = 0.005):
    """Query the power dist board repeatedly and print averaged readings.

    Sends ``int(sampling_duration / wait_dt)`` read requests on CAN bus 5,
    collects the output voltage and current from every valid reply, then
    prints their averages. If no valid reply was decoded, a message saying
    so is printed instead of the averages.

    Args:
        sampling_duration: Nominal sampling window in seconds; together
            with ``wait_dt`` it sets the number of query cycles.
        wait_dt: Time in seconds slept after each query cycle. This is in
            addition to the time spent waiting for the transport.
    """
    can_config = {5: moteus_pi3hat.CanConfiguration()}
    transport = moteus_pi3hat.Pi3HatRouter(can=can_config)
    voltages = []
    currents = []

    nb_cycles = int(sampling_duration / wait_dt)
    for _ in range(nb_cycles):
        results = await transport.cycle(
            [_make_power_query()],
            force_can_check=(1 << 5),
        )

        for result in results:
            if hasattr(result, "id"):  # moteus result
                continue

            for voltage, current in _decode_power_samples(bytes(result.data)):
                voltages.append(voltage)
                currents.append(current)

        await asyncio.sleep(wait_dt)

    if not voltages:
        # Averaging an empty list would only yield a NumPy warning and "nan".
        print("No valid reply received from the power dist board")
        return

    print(f"Output voltage: {np.average(voltages):.1f} V")
    print(f"Output current: {np.average(currents):.1f} A")
    print(f"Averaged over {sampling_duration:.2f} seconds")


if __name__ == "__main__":
    # The script is re-executed through sudo when not already running as
    # root; "-E" asks sudo to keep the current environment.
    running_as_root = os.geteuid() == 0
    if not running_as_root:
        command = ["sudo", "-E", sys.executable, *sys.argv]
        # execlpe replaces this process; its last argument is the environment.
        os.execlpe("sudo", *command, os.environ)
    asyncio.run(main())
