import asyncio
from fireworks.flumina import FluminaModule, main as flumina_main
import fireworks.flumina.route as route
from pydantic import BaseModel
from safetensors import safe_open
from fastapi import WebSocket
import torch
from typing import Dict
import os

# Define your request and response schemata here
class ModuleRequest(BaseModel):
    # Request body accepted by the POST /infer route.
    # input_val: integer handed to the model's forward() as an embedding index.
    input_val: int


class ModuleResponse(BaseModel):
    # Response body returned by the POST /infer route.
    # output_val: the model output converted to a Python float via Tensor.item().
    output_val: float


class MyFluminaModule(FluminaModule):
    """Example Flumina module.

    Contains an embedding -> linear model, swappable "linear_addon" heads
    loaded from safetensors files, an HTTP inference route and an echo
    WebSocket route.
    """

    def __init__(self):
        super().__init__()
        # Add your initialization logic here
        #
        # Example below
        self.embedding = torch.nn.Embedding(5, 1024)
        self.linear = torch.nn.Linear(1024, 1)

        # Loaded addon heads, keyed by "<account_id>/<model_id>".
        self.linear_addons: Dict[str, torch.nn.Linear] = {}
        # Key of the addon head used by forward(), or None to use self.linear.
        self.active_addon = None

        # If running in a multi-GPU distributed scenario, initialize the PyTorch
        # distributed process group.
        if int(os.environ.get("WORLD_SIZE", "1")) > 1:
            torch.distributed.init_process_group()

    @staticmethod
    def _addon_key(addon_account_id: str, addon_model_id: str) -> str:
        """Return the key under which an addon is stored in ``linear_addons``."""
        return f"{addon_account_id}/{addon_model_id}"

    def forward(self, input_val: int):
        """Run the example model on a single embedding index.

        Args:
            input_val: Index into the embedding table; must satisfy
                ``0 <= input_val < num_embeddings``.

        Returns:
            Output tensor of the active addon head, or of ``self.linear`` when
            no addon is active. The tensor is all-reduced in place when
            ``torch.distributed`` is initialized.

        Raises:
            ValueError: If ``input_val`` is outside the embedding table.
        """
        num_embeddings = self.embedding.num_embeddings
        # Valid indices are 0..num_embeddings-1. Reject the upper bound itself
        # and negatives here, instead of letting nn.Embedding raise IndexError.
        if not 0 <= input_val < num_embeddings:
            raise ValueError(
                f"Input value must be non-negative and less than {num_embeddings}"
            )
        x = self.embedding(
            torch.tensor([input_val], device=self.embedding.weight.device)
        )
        if self.active_addon is not None:
            x = self.linear_addons[self.active_addon](x)
        else:
            x = self.linear(x)
        # Torch distributed operation for demonstration
        if torch.distributed.is_initialized():
            torch.distributed.all_reduce(x)
        return x

    @route.post("/infer")
    async def infer(self, input: ModuleRequest):
        """HTTP endpoint: run forward() on ``input.input_val`` and return a float."""
        # Pure inference, so skip recording an autograd graph.
        with torch.no_grad():
            model_out = self(input.input_val)
        return ModuleResponse(output_val=model_out.item())

    # Addon interface
    def load_addon(
        self, addon_account_id: str, addon_model_id: str, addon_type: str, addon_data_path: os.PathLike
    ):
        """Load a ``linear_addon`` head from ``<addon_data_path>/model.safetensors``.

        The new head has the same shape as ``self.linear`` and lives on the
        same device. This lets forward() swap it in without a device mismatch.

        Raises:
            ValueError: If ``addon_type`` is not ``"linear_addon"``.
        """
        if addon_type != "linear_addon":
            raise ValueError(f"Invalid addon type {addon_type}")

        # Follow the base head's device rather than assuming CUDA, so CPU-only
        # runs and non-default GPUs work.
        device = self.linear.weight.device
        new_linear_layer = torch.nn.Linear(
            self.linear.in_features, self.linear.out_features, device=device
        )

        # Read on CPU; load_state_dict copies the values into the layer's
        # existing parameters on ``device``.
        with safe_open(
            os.path.join(addon_data_path, "model.safetensors"),
            framework="pt",
            device="cpu",
        ) as f:
            state_dict = {k: f.get_tensor(k) for k in f.keys()}

        # Strict by default: missing or unexpected keys raise.
        new_linear_layer.load_state_dict(state_dict)

        name = self._addon_key(addon_account_id, addon_model_id)
        self.linear_addons[name] = new_linear_layer

    def unload_addon(
        self, addon_account_id: str, addon_model_id: str, addon_type: str
    ):
        """Remove a loaded addon head.

        If the removed head is the active one, forward() falls back to
        ``self.linear``.

        Raises:
            ValueError: If no addon with this id is loaded.
        """
        name = self._addon_key(addon_account_id, addon_model_id)
        # Explicit check (not ``assert``) so validation survives ``python -O``.
        if name not in self.linear_addons:
            raise ValueError(f"Tried to unload addon {name} but it is not loaded")
        del self.linear_addons[name]
        # Don't leave forward() pointing at a head that no longer exists.
        if self.active_addon == name:
            self.active_addon = None

    def activate_addon(self, addon_account_id: str, addon_model_id: str):
        """Make a loaded addon head the one used by forward().

        Raises:
            ValueError: If another addon is already active, or if this addon
                has not been loaded.
        """
        name = self._addon_key(addon_account_id, addon_model_id)
        if self.active_addon is not None:
            raise ValueError("Multiple active addons not supported")
        # Fail here rather than with a KeyError on the next forward().
        if name not in self.linear_addons:
            raise ValueError(
                f"Tried to activate addon {name} but it is not loaded"
            )

        self.active_addon = name

    def deactivate_addon(self, addon_account_id: str, addon_model_id: str):
        """Stop using the given addon head; forward() reverts to ``self.linear``.

        Raises:
            ValueError: If the given addon is not the currently active one.
        """
        name = self._addon_key(addon_account_id, addon_model_id)
        if self.active_addon != name:
            raise ValueError(
                f"Tried to deactivate addon {name} but it is not currently active"
            )

        self.active_addon = None

    @route.websocket("/ws")
    async def ws(self, websocket: WebSocket):
        """Echo WebSocket: send each received text frame back until the client leaves."""
        from fastapi import WebSocketDisconnect

        # WebSockets are also supported for fast real-time experiences
        await websocket.accept()
        # Add your WebSocket logic here
        try:
            while True:
                data = await websocket.receive_text()
                await websocket.send_text(data)
        except WebSocketDisconnect:
            # A client hang-up is the normal way this loop ends; return cleanly
            # instead of letting the exception escape the handler.
            return

if __name__ == "__flumina_main__":
    # Entry point taken when this file runs under the module name
    # "__flumina_main__" (presumably set by the Flumina runtime — confirm):
    # build the module and hand it to Flumina's main().
    f = MyFluminaModule()
    flumina_main(f)

if __name__ == "__main__":
    # Offline smoke test, run directly with `python <this file>`.
    # Add your offline testing logic here
    f = MyFluminaModule()
    x = 3

    # 1. Calling the module directly yields a single-element tensor.
    direct_out = f(x)
    assert isinstance(direct_out.item(), float)

    # 2. The /infer coroutine can be driven without a server.
    response = asyncio.run(f.infer(ModuleRequest(input_val=x)))
    assert isinstance(response.output_val, float)

    # 3. Linear addon round trip: load from disk, activate, run, deactivate.
    account_id, addon_id = "my_account", "my_addon"
    f.load_addon(account_id, addon_id, "linear_addon", "addon/data")
    f.activate_addon(account_id, addon_id)
    addon_out = f(x)
    f.deactivate_addon(account_id, addon_id)
    assert isinstance(addon_out.item(), float)
