#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright 2022 Stéphane Caron
# Copyright 2023 Inria

"""
Command-line script to communicate with Upkie's servos.
"""

import argparse
import asyncio
import math
import os
import sys
from dataclasses import dataclass
from typing import Dict, List
from pathlib import Path

import moteus
import moteus_pi3hat

# Command keywords accepted as the positional command-line argument.
COMMANDS = [
    "rezero",
    "stats",
    "stop",
]


@dataclass
class Servo:
    """
    Identify a servo by its own ID and by the bus it is attached to.

    Attributes:
        id: Identifier of the servo.
        bus: Identifier of the bus the servo is listed under.
    """

    id: int
    bus: int


class UpkieTool:

    """
    Interface to send the same command to all servos in a given layout.

    The layout is given as a pi3hat configuration string of the form
    ``"<bus>=<id>,<id>,...;<bus>=<id>,..."``, for instance
    ``"1=1,2,3;2=4,5,6"`` for servos 1, 2, 3 on bus 1 and servos 4, 5, 6 on
    bus 2.

    Attributes:
        controllers: List of moteus controllers, one for each servo.
        transport: pi3hat transport.
    """

    # Layout used when no configuration string is passed to the constructor.
    DEFAULT_PI3HAT_CFG = "1=1,2,3;2=4,5,6"

    transport: moteus_pi3hat.Pi3HatRouter
    controllers: List[moteus.Controller]

    def __init__(self, pi3hat_cfg: str = DEFAULT_PI3HAT_CFG):
        """
        Initialize servo tool.

        Args:
            pi3hat_cfg: Servo layout as a pi3hat configuration string, see
                the class documentation for its format.

        Raises:
            ValueError: If the configuration string is malformed.
        """
        servos = self.__get_servos(pi3hat_cfg)
        servo_bus_map = self.__get_servo_bus_map(servos)
        transport = moteus_pi3hat.Pi3HatRouter(servo_bus_map=servo_bus_map)
        controllers = [
            moteus.Controller(id=servo.id, transport=transport)
            for servo in servos
        ]
        self.controllers = controllers
        self.transport = transport

    def __get_servos(
        self, pi3hat_cfg: str = DEFAULT_PI3HAT_CFG
    ) -> List[Servo]:
        """
        Get list of servos from a pi3hat config string.

        Empty bus entries and empty servo entries (e.g. from a trailing ";"
        or ",") are skipped.

        Args:
            pi3hat_cfg: Servo layout as a pi3hat configuration string.

        Returns:
            List of servos, in the order they appear in the string.

        Raises:
            ValueError: If a bus entry has no "=" separator, or if a bus or
                servo identifier is not an integer.
        """
        servos = []
        for bus_cfg in filter(None, pi3hat_cfg.split(";")):
            bus_id, separator, servo_cfg = bus_cfg.partition("=")
            if not separator:
                raise ValueError(
                    f"Missing '=' in pi3hat bus configuration {bus_cfg!r}"
                )
            bus = int(bus_id)
            for servo_id in filter(None, servo_cfg.split(",")):
                servos.append(Servo(id=int(servo_id), bus=bus))
        return servos

    def __get_servo_bus_map(self, servos: List[Servo]) -> Dict[int, List[int]]:
        """
        Extract servo bus -> id map from servo layout.

        Args:
            servos: List of servos.

        Returns:
            Dictionary with a list of servo IDs for each bus ID.
        """
        servo_bus_map: Dict[int, List[int]] = {}
        for servo in servos:
            servo_bus_map.setdefault(servo.bus, []).append(servo.id)
        return servo_bus_map

    async def set_stop(self) -> None:
        """
        Send stop commands to moteus controllers.
        """
        # Flush so that the message is visible while the cycle is awaited,
        # since there is no newline to trigger line-buffered output.
        print("Sending stop commands to all motors... ", end="", flush=True)
        await self.transport.cycle(
            [controller.make_stop() for controller in self.controllers]
        )
        print("done")

    async def set_rezero(self) -> None:
        """
        Send rezero commands to moteus controllers.
        """
        # Flush for the same reason as in set_stop.
        print("Sending rezero commands to all motors... ", end="", flush=True)
        await self.transport.cycle(
            [controller.make_rezero() for controller in self.controllers]
        )
        print("done")

    async def print_stats(self) -> None:
        """
        Query controllers for their current stats.

        The query is carried by stop commands. One table row is printed for
        each result, sorted by servo ID, with the mode, position, velocity
        and torque registers. Position and velocity are multiplied by 2 * pi
        to match the radian units of the column headers, and all three
        quantities are rounded to three decimals.
        """

        def get_result_id(result) -> int:  # help mypy
            result_id: int = result.id
            return result_id

        results = await self.transport.cycle(
            [
                controller.make_stop(query=True)
                for controller in self.controllers
            ]
        )
        sorted_results = sorted(results, key=get_result_id)
        print(
            f"{'id':2} "
            f"{'Mode':6} "
            f"{'Position (rad)':14} "
            f"{'Velocity (rad / s)':18} "
            f"{'Torque (N * m)':14}"
        )
        print(
            f"{'--':2} "
            f"{'----':6} "
            f"{'--------------':14} "
            f"{'------------------':18} "
            f"{'--------------':14}"
        )
        for result in sorted_results:
            mode = result.values[moteus.Register.MODE]
            position = result.values[moteus.Register.POSITION]
            velocity = result.values[moteus.Register.VELOCITY]
            torque = result.values[moteus.Register.TORQUE]
            position_rad = round(position * 2.0 * math.pi, 3)
            velocity_rad_s = round(velocity * 2 * math.pi, 3)
            torque_N_m = round(torque, 3)
            print(
                f"{result.id:2} "
                f"{mode:6} "
                f"{position_rad:14} "
                f"{velocity_rad_s:18} "
                f"{torque_N_m:14}"
            )


async def main(upkie_tool: UpkieTool, command: str) -> None:
    """
    Run a command on all servos, sending stop commands before and after.

    Exceptions raised while running the command are printed and otherwise
    ignored, so that the final stop commands are sent regardless. A command
    keyword that is not recognized results in the two stops only.

    Args:
        upkie_tool: Instance of the UpkieTool class.
        command: Command keyword, one of "rezero", "stats" or "stop".
    """
    await upkie_tool.set_stop()
    try:
        if command == "stop":
            await upkie_tool.set_stop()
        elif command == "stats":
            await upkie_tool.print_stats()
        elif command == "rezero":
            await upkie_tool.set_rezero()
            # Marker file, only created when the rezero cycle did not raise
            Path("/tmp/rezero_success").touch()
    except Exception as error:
        print(f"Ignoring exception: {error}")
    await upkie_tool.set_stop()


def parse_command_line_arguments() -> argparse.Namespace:
    """
    Read the command keyword from the command line.

    The parser takes a single positional argument, ``command``, restricted
    to the keywords listed in ``COMMANDS``. The module docstring is used as
    the program description.

    Returns:
        Namespace resulting from parsing command-line arguments.
    """
    cli = argparse.ArgumentParser(description=__doc__)
    cli.add_argument("command", choices=COMMANDS, help="Command to execute")
    namespace = cli.parse_args()
    return namespace


if __name__ == "__main__":
    if os.geteuid() != 0:
        # Not running as root: replace this process with the same invocation
        # run through "sudo -E", handing over the current environment.
        sudo_argv = ["sudo", "-E", sys.executable, *sys.argv]
        os.execvpe("sudo", sudo_argv, os.environ)
    cli_args = parse_command_line_arguments()
    upkie_tool = UpkieTool()
    asyncio.run(main(upkie_tool, cli_args.command))
