Metadata-Version: 2.4
Name: sqs-fl-framework
Version: 1.2.0
Summary: Model-agnostic Federated Learning framework with Amazon SQS event bus
Author-email: Randini_Maliksha <maliksharandini@gmail.com>
License-Expression: MIT
Keywords: federated-learning,sqs,machine-learning,distributed
Classifier: Development Status :: 4 - Beta
Classifier: Intended Audience :: Developers
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Requires-Python: >=3.10
Description-Content-Type: text/markdown
Requires-Dist: numpy>=1.24.0
Requires-Dist: pandas>=2.0.0
Requires-Dist: requests>=2.31.0
Requires-Dist: boto3>=1.34.0
Requires-Dist: python-dotenv>=1.0.0
Requires-Dist: pymongo>=4.6.0
Provides-Extra: dev
Requires-Dist: pytest>=8.0; extra == "dev"
Requires-Dist: pytest-cov>=5.0; extra == "dev"
Requires-Dist: moto[sqs]>=5.0; extra == "dev"
Requires-Dist: black>=24.0; extra == "dev"
Requires-Dist: ruff>=0.4; extra == "dev"
Requires-Dist: mypy>=1.9; extra == "dev"

# fl-framework

A model-agnostic Federated Learning framework with an Amazon SQS event bus.  
Bring your own model — the framework handles the orchestration.

---

## Features

| Feature | Description |
|---|---|
| **Model-agnostic** | Wrap any model (scikit-learn, PyTorch, XGBoost, ensembles) with a simple interface |
| **SQS event bus** | Clients and server communicate via Amazon SQS long-polling — no direct network connections between nodes |
| **FedAvg aggregation** | Sample-weighted averaging for float parameters; quality-weighted selection for non-averageable blobs (e.g. tree models) |
| **Uniform aggregation** | Optional equal-weight strategy regardless of local dataset size |
| **In-process simulation** | Run full FL rounds without SQS or a real backend — ideal for testing and research |
| **Dynamic data loading** | `train_data_fn` / `eval_data_fn` callables are invoked per round — supports live DB queries |
| **Auto round advancement** | Server can automatically start the next round after aggregation |
| **Round timeout** | Configurable deadline; raises `RoundTimeoutError` if not enough clients submit in time |
| **Async client** | `FLClient.start_async()` runs in a background thread so it doesn't block your application |
| **Round lifecycle hooks** | `on_round_start` / `on_round_end` optional callbacks on the model for custom side-effects |
| **Environment-variable config** | Every setting has an env-var override — no code changes needed between environments |
| **Typed exceptions** | `WeightTransportError`, `SQSError`, `AggregationError`, `RoundTimeoutError` for granular error handling |

---

## Installation

```bash
pip install -e .
```

Dependencies: `numpy`, `pandas`, `requests`, `boto3`, `python-dotenv`, `pymongo`.

---

## Core Concepts

```
Edge Client 1 ──┐                          ┌── Edge Client 1
Edge Client 2 ──┤── SQS ──► FL Server ────┤── Edge Client 2
Edge Client 3 ──┘   (WEIGHTS_SUBMIT)       └── Edge Client 3
                     (ROUND_START / MODEL_BROADCAST)
                            │
                     Backend REST API
                     (weight storage)
```

1. Server broadcasts `ROUND_START` to all client queues.
2. Each client trains locally and submits weights via `WEIGHTS_SUBMIT`.
3. Once enough clients have submitted, the server runs FedAvg and broadcasts `MODEL_BROADCAST`.
4. Repeat.
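
The loop above can be sketched in a single process, with plain Python queues standing in for SQS. Everything here is a hypothetical illustration rather than the framework's wire format or internals: the message dictionaries, the `run_round` helper, and the "training" step (which just adds a fixed per-client delta) are assumptions made for the example.

```python
import queue
import numpy as np

def run_round(round_, global_w, client_queues, server_queue, local_deltas):
    """One FL round over in-memory queues (hypothetical message format)."""
    # 1. Server broadcasts ROUND_START to every client queue.
    for q in client_queues.values():
        q.put({"type": "ROUND_START", "round": round_, "weights": global_w})

    # 2. Each client "trains" (adds a fixed delta) and submits WEIGHTS_SUBMIT.
    for cid, q in client_queues.items():
        msg = q.get_nowait()
        delta, num_samples = local_deltas[cid]
        server_queue.put({
            "type": "WEIGHTS_SUBMIT", "client": cid,
            "weights": msg["weights"] + delta, "num_samples": num_samples,
        })

    # 3. Server aggregates (sample-weighted mean) and broadcasts the result.
    updates = [server_queue.get_nowait() for _ in client_queues]
    total = sum(u["num_samples"] for u in updates)
    new_w = sum(u["weights"] * (u["num_samples"] / total) for u in updates)
    for q in client_queues.values():
        q.put({"type": "MODEL_BROADCAST", "round": round_, "weights": new_w})

    # Clients pick up the new global model, leaving their queues empty.
    for q in client_queues.values():
        q.get_nowait()
    return new_w

client_queues = {"c1": queue.Queue(), "c2": queue.Queue()}
server_queue = queue.Queue()
local_deltas = {
    "c1": (np.array([1.0, 0.0]), 100),   # (fake training delta, num_samples)
    "c2": (np.array([0.0, 1.0]), 300),
}

w = np.zeros(2)
for r in range(2):                       # 4. Repeat.
    w = run_round(r, w, client_queues, server_queue, local_deltas)
```

Because `c2` holds three times as many samples as `c1`, its delta dominates the averaged weights after each round.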

---

## Quick Start

### Step 1 — Implement the FLModel interface

```python
import numpy as np
import pandas as pd
from fl_framework import FLModel

class MyLinearModel(FLModel):
    MODEL_ID = "my_linear_v1"          # unique ID, matches server config

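    def __init__(self, n_features: int = 1):
        # scikit-learn is not a framework dependency; install it separately.
        from sklearn.linear_model import Ridge
        self._model = Ridge()
        # Ridge has no coef_ / intercept_ until fit() is called, so seed zeros
        # to let get_weights() work on a fresh instance (e.g. for round 0).
        self._model.coef_ = np.zeros(n_features)
        self._model.intercept_ = 0.0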

    def get_weights(self) -> dict[str, np.ndarray]:
        """Return model parameters as numpy arrays."""
        return {
            "coef":      self._model.coef_,
            "intercept": np.array([self._model.intercept_]),
        }

    def set_weights(self, weights: dict[str, np.ndarray]) -> None:
        """Apply aggregated global weights."""
        self._model.coef_      = weights["coef"]
        self._model.intercept_ = weights["intercept"][0]

    def train(self, data: pd.DataFrame, **kwargs) -> dict:
        """Local training pass. Return a metadata dict."""
        X, y = data.drop(columns=["target"]), data["target"]
        self._model.fit(X, y)
        return {"samples": len(data)}

    def evaluate(self, data: pd.DataFrame) -> dict[str, float]:
        """Return a metrics dict. Must include at least one numeric key."""
        from sklearn.metrics import r2_score
        X, y = data.drop(columns=["target"]), data["target"]
        return {"r2": float(r2_score(y, self._model.predict(X)))}

    def serialize(self) -> bytes:
        import pickle
        return pickle.dumps(self._model, protocol=4)

    @classmethod
    def deserialize(cls, data: bytes) -> "MyLinearModel":
        import pickle
        inst = cls.__new__(cls)            # skip __init__; the pickled estimator replaces it
        inst._model = pickle.loads(data)   # only unpickle bytes from a trusted source
        return inst
```

---

### Step 2 — Configure

```python
from fl_framework import FLConfig

config = FLConfig(
    backend_url             = "https://api.example.com",
    sqs_queue_url           = "https://sqs.us-east-1.amazonaws.com/123/client-1-queue",
    sqs_server_queue_url    = "https://sqs.us-east-1.amazonaws.com/123/server-queue",
    sqs_region              = "us-east-1",
    client_id               = "edge-node-1",
    model_id                = "my_linear_v1",
    min_clients_for_aggregation = 3,
    max_rounds              = 10,
    round_timeout_s         = 300,
)
```

All fields can also be set via environment variables (see the Environment Variables table below).

---

### Step 3 — Run a Client

```python
from fl_framework import FLClient
import pandas as pd

def get_train_data():
    return pd.read_csv("my_local_data.csv")

client = FLClient(
    model         = MyLinearModel(),
    config        = config,
    train_data_fn = get_train_data,
)

client.start()           # blocking — listens for SQS events until stop()
# OR
client.start_async()     # background thread
# ...later...
client.stop()
```

---

### Step 4 — Run the Server

```python
from fl_framework import FLServer

server = FLServer(
    config = config,
    client_queue_urls = [
        "https://sqs.us-east-1.amazonaws.com/123/client-1-queue",
        "https://sqs.us-east-1.amazonaws.com/123/client-2-queue",
        "https://sqs.us-east-1.amazonaws.com/123/client-3-queue",
    ],
    on_round_complete = lambda round_, weights, metrics: print(f"Round {round_} done: {metrics}"),
)

initial_weights = MyLinearModel().get_weights()
server.start_round(round_=0, global_weights=initial_weights)
server.start()           # blocking; auto-advances rounds
```

---

### In-Process Simulation (no SQS, no backend)

Useful for research and testing — runs everything in a single Python process.

```python
from fl_framework import FLClient, FLConfig
from fl_framework.aggregator import ClientUpdate, Aggregator
import pandas as pd

data   = pd.read_csv("data.csv")
model  = MyLinearModel()
config = FLConfig()          # defaults are fine for simulation

# --- one round manually ---
client = FLClient(model=model, config=config)
metrics = client.run_round(
    round_          = 0,
    global_weights  = model.get_weights(),
    local_epochs    = 50,
    train_data      = data,
    eval_data       = data,
)
print(metrics)   # illustrative output, e.g. {"r2": 0.97, "samples": 200, ...}

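# --- aggregate multiple clients ---
# weights_c1 / weights_c2 are not defined in this snippet: they stand for each
# client's get_weights() output (dict[str, np.ndarray]) after local training.
aggregator = Aggregator()
updates = [
    ClientUpdate("c1", weights_c1, num_samples=100, metrics={"r2": 0.97}),
    ClientUpdate("c2", weights_c2, num_samples=120, metrics={"r2": 0.95}),
]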
global_weights = aggregator.aggregate(updates, strategy="fedavg")
```

---

## FLModel Interface — all methods

| Method | Required | Description |
|---|---|---|
| `get_weights() → dict[str, ndarray]` | ✅ | Extract current model parameters |
| `set_weights(dict)` | ✅ | Apply received aggregated parameters |
| `train(df, **kwargs) → dict` | ✅ | Run a local training pass |
| `evaluate(df) → dict[str, float]` | ✅ | Evaluate and return metrics |
| `serialize() → bytes` | ✅ | Serialise the full model to bytes |
| `deserialize(bytes) → FLModel` | ✅ | Reconstruct from bytes (classmethod) |
| `on_round_start(round_, weights)` | ❌ | Hook called before local training |
| `on_round_end(round_, metrics)` | ❌ | Hook called after local training |
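
As a rough illustration of the two optional hooks, the mixin below records how long each round takes. `TimingHooks` and its `round_durations` attribute are hypothetical names; only the hook signatures come from the table above, and the sketch assumes the framework finds the hooks by ordinary method lookup on the model.

```python
import time

class TimingHooks:
    """Hypothetical mixin, e.g. class MyLinearModel(TimingHooks, FLModel)."""

    def on_round_start(self, round_, weights):
        # Called before local training: remember when the round began.
        self._round_t0 = time.perf_counter()

    def on_round_end(self, round_, metrics):
        # Called after local training: store the elapsed seconds per round.
        # State is created lazily so the mixin still works when a model is
        # rebuilt without running __init__.
        durations = self.__dict__.setdefault("round_durations", {})
        durations[round_] = time.perf_counter() - self._round_t0

# Calling the hooks by hand to show the recorded state.
hooks = TimingHooks()
hooks.on_round_start(0, weights={})
hooks.on_round_end(0, metrics={"r2": 0.9})
print(hooks.round_durations)
```

Keeping the side-effect in a mixin leaves the required methods (`train`, `evaluate`, and so on) untouched.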

---

## Environment Variables

| Variable | Default | Description |
|---|---|---|
| `FL_BACKEND_URL` | `http://localhost:8000` | Backend REST API base URL |
| `FL_BACKEND_API_KEY` | _(none)_ | Optional Bearer token |
| `FL_SQS_QUEUE_URL` | _(empty)_ | Client's inbound SQS queue |
| `FL_SQS_SERVER_QUEUE_URL` | _(empty)_ | Server's inbound SQS queue |
| `FL_SQS_REGION` | `us-east-1` | AWS region |
| `FL_CLIENT_ID` | _(random uuid)_ | Stable identifier for this node |
| `FL_MODEL_ID` | `default` | Model identifier string |
| `FL_ROUND_TIMEOUT_S` | `300` | Seconds to wait before `RoundTimeoutError` |
| `FL_MAX_ROUNDS` | `100` | Stop after this many rounds |
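
To make the override behaviour concrete, here is a minimal sketch that resolves these variables by hand. `resolve_fl_settings` is a hypothetical helper, not part of the framework; the defaults are copied from the table above, and the framework's own parsing (for example, whether the timeout is an int or a float) may differ.

```python
import os
import uuid

# Defaults copied from the table above; FL_CLIENT_ID is handled separately.
_DEFAULTS = {
    "FL_BACKEND_URL": "http://localhost:8000",
    "FL_BACKEND_API_KEY": None,
    "FL_SQS_QUEUE_URL": "",
    "FL_SQS_SERVER_QUEUE_URL": "",
    "FL_SQS_REGION": "us-east-1",
    "FL_MODEL_ID": "default",
    "FL_ROUND_TIMEOUT_S": "300",
    "FL_MAX_ROUNDS": "100",
}
_INT_KEYS = ("FL_ROUND_TIMEOUT_S", "FL_MAX_ROUNDS")

def resolve_fl_settings(env=None):
    """Return a dict of settings: environment value if set, else the default."""
    env = os.environ if env is None else env
    settings = {key: env.get(key, default) for key, default in _DEFAULTS.items()}
    # A random UUID is only a fallback; set FL_CLIENT_ID for a stable identity.
    settings["FL_CLIENT_ID"] = env.get("FL_CLIENT_ID") or str(uuid.uuid4())
    for key in _INT_KEYS:
        settings[key] = int(settings[key])   # environment values are strings
    return settings

# A plain dict stands in for os.environ here.
settings = resolve_fl_settings({"FL_SQS_REGION": "eu-west-1", "FL_MAX_ROUNDS": "10"})
print(settings["FL_SQS_REGION"], settings["FL_MAX_ROUNDS"])
```

Passing a dict instead of reading `os.environ` directly keeps the helper easy to test.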

---

## Aggregation Strategies

| Strategy | Behaviour | When to use |
|---|---|---|
| `"fedavg"` _(default)_ | Weighted average by `num_samples` | Standard FL — more data = more influence |
| `"uniform"` | Equal weight per client | All clients have equal trust regardless of data size |

**Note on tree models / non-averageable parameters:** Any weight key whose numpy array has `dtype=uint8` is treated as a serialised blob. The framework performs *selection* (picks the highest-weight client's blob) rather than averaging, since trees are not additively composable.
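
The two strategies and the blob rule can be sketched in a few lines of NumPy. This is an illustration under stated assumptions, not the framework's `Aggregator`: `aggregate_sketch` is a hypothetical function, updates are plain `(weights, num_samples)` tuples, and "highest-weight client" is read as the client with the largest aggregation coefficient, with ties going to the first client.

```python
import numpy as np

def aggregate_sketch(updates, strategy="fedavg"):
    """updates: list of (weights: dict[str, np.ndarray], num_samples: int)."""
    if not updates:
        raise ValueError("need at least one client update")
    if strategy == "fedavg":
        raw = np.array([n for _, n in updates], dtype=float)
    elif strategy == "uniform":
        raw = np.ones(len(updates))
    else:
        raise ValueError(f"unknown strategy: {strategy!r}")
    if raw.sum() <= 0:
        raise ValueError("total num_samples must be positive")

    coeffs = raw / raw.sum()          # per-client weights, summing to 1
    best = int(np.argmax(coeffs))     # client whose blob is selected

    out = {}
    for key in updates[0][0]:
        arrays = [weights[key] for weights, _ in updates]
        if arrays[0].dtype == np.uint8:
            # Serialised blob: select one client's copy instead of averaging.
            out[key] = arrays[best].copy()
        else:
            # Float parameters: weighted average across clients.
            out[key] = sum(c * a for c, a in zip(coeffs, arrays))
    return out

updates = [
    ({"coef": np.array([1.0, 2.0]), "tree": np.frombuffer(b"A", dtype=np.uint8)}, 100),
    ({"coef": np.array([3.0, 6.0]), "tree": np.frombuffer(b"B", dtype=np.uint8)}, 300),
]
fedavg = aggregate_sketch(updates)                 # coefficients 0.25 / 0.75
uniform = aggregate_sketch(updates, "uniform")     # coefficients 0.5 / 0.5
print(fedavg["coef"], uniform["coef"])
```

With `"fedavg"`, the second client's 300 samples pull the averaged `coef` toward its values and its blob is selected; with `"uniform"`, both clients count equally.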
