#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright 2022 Stéphane Caron
# Copyright 2023 Inria

import argparse
import asyncio
import math
import os
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional

import moteus
import moteus_pi3hat

# Command keywords accepted on the command line (used as argparse choices
# and dispatched in main()).
COMMANDS = [
    "rezero",
    "stats",
    "stop",
]


@dataclass
class Servo:
    """
    Servo entry of the layout: its ID and the pi3hat bus it sits on.

    Attributes:
        id: Servo ID, passed to the moteus controller.
        bus: Bus number used as key in the pi3hat servo-bus map.
    """

    id: int  # servo ID
    bus: int  # pi3hat bus number


class UpkieTool:

    """
    Interface to send the same command to all servos in a given layout.

    The layout is a pi3hat configuration string of the form
    ``"bus=id,id,...;bus=id,..."``. For instance ``"1=1,2,3;2=4,5,6"``
    puts servos 1 to 3 on bus 1 and servos 4 to 6 on bus 2.

    Attributes:
        controllers: List of moteus controllers, one for each servo.
        transport: pi3hat transport.
    """

    # Servo layout used when none is passed to the constructor.
    DEFAULT_PI3HAT_CFG = "1=1,2,3;2=4,5,6"

    transport: moteus_pi3hat.Pi3HatRouter
    controllers: List[moteus.Controller]

    def __init__(self, pi3hat_cfg: str = DEFAULT_PI3HAT_CFG):
        """
        Initialize servo tool.

        Args:
            pi3hat_cfg: Servo layout string (see class docstring). Defaults
                to ``DEFAULT_PI3HAT_CFG``.

        Raises:
            ValueError: If the layout string is malformed.
        """
        servos = self.__get_servos(pi3hat_cfg)
        servo_bus_map = self.__get_servo_bus_map(servos)
        transport = moteus_pi3hat.Pi3HatRouter(servo_bus_map=servo_bus_map)
        self.controllers = [
            moteus.Controller(id=servo.id, transport=transport)
            for servo in servos
        ]
        self.transport = transport

    def __get_servos(
        self, pi3hat_cfg: str = DEFAULT_PI3HAT_CFG
    ) -> List[Servo]:
        """
        Get list of servos from a pi3hat config string.

        Args:
            pi3hat_cfg: Layout string ``"bus=id,id,...;bus=id,..."``. Empty
                bus or servo entries (e.g. trailing separators) are skipped.

        Returns:
            List of servos, in the order they appear in the string.

        Raises:
            ValueError: If a bus entry does not contain exactly one ``=``,
                or if a bus or servo ID is not an integer.
        """
        servos = []
        for bus in filter(None, pi3hat_cfg.split(";")):
            bus_id, sep, servo_cfg = bus.partition("=")
            if not sep or "=" in servo_cfg:
                raise ValueError(
                    f"Invalid bus entry {bus!r} in pi3hat config "
                    f"{pi3hat_cfg!r}: expected 'bus=id,id,...'"
                )
            for servo in filter(None, servo_cfg.split(",")):
                servos.append(Servo(id=int(servo), bus=int(bus_id)))
        return servos

    def __get_servo_bus_map(self, servos: List[Servo]) -> Dict[int, List[int]]:
        """
        Extract servo bus -> id map from servo layout.

        Args:
            servos: List of servos.

        Returns:
            Dictionary with a list of servo IDs for each bus ID, in the
            order the servos were listed.
        """
        servo_bus_map: Dict[int, List[int]] = {}
        for servo in servos:
            servo_bus_map.setdefault(servo.bus, []).append(servo.id)
        return servo_bus_map

    async def set_stop(self) -> None:
        """
        Send stop commands to moteus controllers.
        """
        # Flush so the message is visible while the cycle is awaited, even
        # if it never completes.
        print("Sending stop commands to all motors... ", end="", flush=True)
        await self.transport.cycle(
            [controller.make_stop() for controller in self.controllers]
        )
        print("done")

    async def set_rezero(self) -> None:
        """
        Send rezero commands to moteus controllers.
        """
        # Flush so the message is visible while the cycle is awaited.
        print("Sending rezero commands to all motors... ", end="", flush=True)
        await self.transport.cycle(
            [controller.make_rezero() for controller in self.controllers]
        )
        print("done")

    async def print_stats(self) -> None:
        """
        Query controllers for their current stats and print them as a table.

        Positions and velocities are multiplied by 2 pi (columns are labeled
        rad and rad / s). Controllers for which no result with a matching ID
        came back are listed after the table.
        """

        def get_result_id(result) -> int:  # help mypy
            result_id: int = result.id
            return result_id

        results = await self.transport.cycle(
            [
                controller.make_stop(query=True)
                for controller in self.controllers
            ]
        )
        sorted_results = sorted(results, key=get_result_id)
        print(
            f"{'id':2} "
            f"{'Mode':6} "
            f"{'Position (rad)':14} "
            f"{'Velocity (rad / s)':18} "
            f"{'Torque (N * m)':14}"
        )
        print(
            f"{'--':2} "
            f"{'----':6} "
            f"{'--------------':14} "
            f"{'------------------':18} "
            f"{'--------------':14}"
        )
        for result in sorted_results:
            mode = result.values[moteus.Register.MODE]
            position = result.values[moteus.Register.POSITION]
            velocity = result.values[moteus.Register.VELOCITY]
            torque = result.values[moteus.Register.TORQUE]
            position_rad = round(position * 2.0 * math.pi, 3)
            velocity_rps = round(velocity * 2.0 * math.pi, 3)
            torque_N_m = round(torque, 3)
            print(
                f"{result.id:2} "
                f"{mode:6} "
                f"{position_rad:14} "
                f"{velocity_rps:18} "
                f"{torque_N_m:14}"
            )

        # Without this check, a servo that did not answer would simply be
        # absent from the table, which is easy to miss.
        replied_ids = {get_result_id(result) for result in results}
        missing_ids = sorted(
            controller.id
            for controller in self.controllers
            if controller.id not in replied_ids
        )
        if missing_ids:
            print(f"No reply from servo(s): {missing_ids}")


async def main(upkie_tool: UpkieTool, command: str) -> None:
    """
    Run a command on all servos, sending stop commands before and after.

    The final stop is sent from a ``finally`` clause. It therefore also runs
    when the coroutine is interrupted by a ``BaseException`` such as
    ``KeyboardInterrupt`` or ``asyncio.CancelledError``. Those exceptions
    still propagate afterwards.

    Args:
        upkie_tool: Instance of the UpkieTool class.
        command: Command keyword: "rezero", "stats" or "stop". Any other
            keyword only sends the surrounding stop commands.

    Notes:
        Exceptions raised while running the command are printed and
        ignored. After a successful rezero, the file
        ``/tmp/rezero_success`` is created (or its timestamp updated).
    """
    await upkie_tool.set_stop()
    try:
        if command == "rezero":
            await upkie_tool.set_rezero()
            Path('/tmp/rezero_success').touch()
        elif command == "stats":
            await upkie_tool.print_stats()
        elif command == "stop":
            await upkie_tool.set_stop()
    except Exception as e:
        # Best effort: report the failure, then still stop the servos.
        print(f"Ignoring exception: {e}")
    finally:
        await upkie_tool.set_stop()


def parse_command_line_arguments(
    argv: Optional[List[str]] = None,
) -> argparse.Namespace:
    """
    Parse command-line arguments.

    Args:
        argv: Arguments to parse, without the program name. When ``None``
            (the default), argparse reads them from ``sys.argv``.

    Returns:
        Namespace resulting from parsing command-line arguments.
    """
    # ``__doc__`` is None when the module has no docstring, which would
    # leave the --help description empty; fall back to a short summary.
    parser = argparse.ArgumentParser(
        description=__doc__ or "Send the same command to all Upkie servos.",
    )
    parser.add_argument(
        "command",
        help="Command to execute",
        choices=COMMANDS,
    )
    return parser.parse_args(argv)


if __name__ == "__main__":
    if os.geteuid() != 0:
        # Not running as root: replace this process with the same
        # interpreter and arguments under sudo (presumably needed for
        # pi3hat hardware access). ``-E`` asks sudo to preserve the
        # environment, which is also handed to the new process explicitly.
        sudo_argv = ["sudo", "-E", sys.executable, *sys.argv]
        os.execvpe("sudo", sudo_argv, os.environ)
    cli_args = parse_command_line_arguments()
    upkie_tool = UpkieTool()
    asyncio.run(main(upkie_tool, cli_args.command))
