Metadata-Version: 2.4
Name: chorus-fl
Version: 0.1.0
Summary: Federated LoRA adapter aggregation framework — exact aggregation with FedEx-LoRA
Project-URL: Homepage, https://github.com/varmabudharaju/chorus
Project-URL: Repository, https://github.com/varmabudharaju/chorus
Project-URL: Documentation, https://github.com/varmabudharaju/chorus#readme
Project-URL: Changelog, https://github.com/varmabudharaju/chorus/blob/main/CHANGELOG.md
Project-URL: Issues, https://github.com/varmabudharaju/chorus/issues
Author: Chorus Contributors
License-Expression: Apache-2.0
License-File: LICENSE
Keywords: aggregation,federated-learning,fedex-lora,fine-tuning,llm,lora,peft
Classifier: Development Status :: 3 - Alpha
Classifier: Intended Audience :: Developers
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: Apache Software License
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Requires-Python: >=3.10
Requires-Dist: click>=8.1.0
Requires-Dist: fastapi>=0.104.0
Requires-Dist: httpx>=0.25.0
Requires-Dist: numpy>=1.24.0
Requires-Dist: python-multipart>=0.0.6
Requires-Dist: rich>=13.0.0
Requires-Dist: safetensors>=0.4.0
Requires-Dist: torch>=2.0.0
Requires-Dist: uvicorn[standard]>=0.24.0
Requires-Dist: websockets>=12.0
Provides-Extra: all
Requires-Dist: accelerate>=0.25.0; extra == 'all'
Requires-Dist: build; extra == 'all'
Requires-Dist: datasets>=2.16.0; extra == 'all'
Requires-Dist: opacus>=1.4.0; extra == 'all'
Requires-Dist: peft>=0.7.0; extra == 'all'
Requires-Dist: pytest-asyncio>=0.23.0; extra == 'all'
Requires-Dist: pytest-httpx>=0.28.0; extra == 'all'
Requires-Dist: pytest>=7.0.0; extra == 'all'
Requires-Dist: ruff>=0.1.0; extra == 'all'
Requires-Dist: transformers>=4.36.0; extra == 'all'
Provides-Extra: dev
Requires-Dist: build; extra == 'dev'
Requires-Dist: pytest-asyncio>=0.23.0; extra == 'dev'
Requires-Dist: pytest-httpx>=0.28.0; extra == 'dev'
Requires-Dist: pytest>=7.0.0; extra == 'dev'
Requires-Dist: ruff>=0.1.0; extra == 'dev'
Provides-Extra: peft
Requires-Dist: accelerate>=0.25.0; extra == 'peft'
Requires-Dist: datasets>=2.16.0; extra == 'peft'
Requires-Dist: peft>=0.7.0; extra == 'peft'
Requires-Dist: transformers>=4.36.0; extra == 'peft'
Provides-Extra: privacy
Requires-Dist: opacus>=1.4.0; extra == 'privacy'
Description-Content-Type: text/markdown

# Chorus

[![License: Apache 2.0](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)
[![Python 3.10+](https://img.shields.io/badge/python-3.10+-blue.svg)](https://www.python.org/downloads/)
[![PyPI version](https://img.shields.io/pypi/v/chorus-fl.svg)](https://pypi.org/project/chorus-fl/)

**Federated LoRA fine-tuning with mathematically exact aggregation.**

Chorus is a framework for federated fine-tuning of large language models using LoRA adapters. Multiple clients train on their private data, submit adapter deltas to a central server, and receive back aggregated improvements — without sharing any raw data.

Standard FedAvg averages the `A` and `B` matrices independently, which is **broken for LoRA** because `avg(B @ A) != avg(B) @ avg(A)`. Chorus implements [FedEx-LoRA](https://arxiv.org/abs/2501.03075) (ACL/ICLR 2025), which provides **exact** federated aggregation by tracking the SVD truncation residual and folding it into the base weights.

## How It Works

```
Client 1 (private data)          Aggregation Server           Client 2 (private data)
┌─────────────────────┐       ┌─────────────────────┐       ┌─────────────────────┐
│  1. Train LoRA      │       │                     │       │  1. Train LoRA      │
│  2. Submit delta  ──┼──POST─┼→ Collect deltas     │←─POST─┼── 2. Submit delta   │
│                     │       │  FedEx-LoRA agg     │       │                     │
│                     │       │  Fold residuals     │       │                     │
│  3. Pull updated  ←─┼──GET──┼─ Serve result       │──GET──┼→ 3. Pull updated    │
│  4. Repeat          │       │  WS: round_complete │       │  4. Repeat          │
└─────────────────────┘       └─────────────────────┘       └─────────────────────┘
```
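
The round flow in the diagram can be pictured as a small in-memory collector. This is an illustration only, not the Chorus server code: `RoundCollector`, `submit`, and the `aggregate` callback are hypothetical names. The one behavior taken from this README is that aggregation triggers once a minimum number of deltas has arrived.

```python
from typing import Callable


class RoundCollector:
    """Toy model of one aggregation round (hypothetical, not the Chorus server)."""

    def __init__(self, min_deltas: int, aggregate: Callable[[list], object]):
        self.min_deltas = min_deltas
        self.aggregate = aggregate
        self.round_id = 0
        self.pending: list = []
        self.latest = None  # result of the most recent completed round

    def submit(self, delta) -> bool:
        """Store a delta; return True if this submission completed the round."""
        self.pending.append(delta)
        if len(self.pending) < self.min_deltas:
            return False
        self.latest = self.aggregate(self.pending)
        self.pending = []
        self.round_id += 1
        return True


# Usage: three clients, scalar "deltas", plain mean as the aggregator.
collector = RoundCollector(min_deltas=3, aggregate=lambda ds: sum(ds) / len(ds))
results = [collector.submit(d) for d in (1.0, 2.0, 6.0)]
```

Only the third submission completes the round; until then clients pulling the latest adapter would still see the previous result.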

## Installation

```bash
pip install chorus-fl
```

With optional dependencies:

```bash
# For local LoRA training (PEFT + Transformers)
pip install "chorus-fl[peft]"

# For differential privacy
pip install "chorus-fl[privacy]"

# Everything
pip install "chorus-fl[all]"
```

From source:

```bash
git clone https://github.com/varmabudharaju/chorus.git
cd chorus
pip install -e ".[dev]"
```

## Quick Start

### 1. Start the server

```bash
chorus server --model meta-llama/Llama-3.2-3B --min-deltas 3
```

### 2. Submit adapters from clients

```python
from chorus import ChorusClient

client = ChorusClient(
    server="http://localhost:8080",
    model_id="meta-llama/Llama-3.2-3B",
)

# After your local LoRA training...
client.submit_delta(adapter_path="./my-adapter")

# Pull the aggregated global adapter
client.pull_latest(output_path="./updated-adapter")

client.close()
```

### 3. Run a simulation (no server needed)

```bash
# Compare FedAvg vs FedEx-LoRA
chorus simulate --clients 10 --rounds 5 --compare
```

## Why FedEx-LoRA?

LoRA represents each weight update as a product of two low-rank matrices, `ΔW = B @ A`. Averaging the two factors separately across clients does not give the average of the products:

```
avg(B_i @ A_i)  !=  avg(B_i) @ avg(A_i)
```
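
The inequality is easy to check numerically. The sketch below uses NumPy with random matrices; the shapes, the client count, and the seed are arbitrary choices for illustration and do not come from Chorus.

```python
import numpy as np

rng = np.random.default_rng(0)
d_out, d_in, rank, n_clients = 8, 6, 2, 3

# One (B_i, A_i) pair per client: B is d_out x rank, A is rank x d_in.
Bs = [rng.normal(size=(d_out, rank)) for _ in range(n_clients)]
As = [rng.normal(size=(rank, d_in)) for _ in range(n_clients)]

# What the clients collectively learned: the mean of the products.
exact = np.mean([B @ A for B, A in zip(Bs, As)], axis=0)

# What factor-wise averaging reconstructs: the product of the means.
naive = np.mean(Bs, axis=0) @ np.mean(As, axis=0)

gap = np.linalg.norm(exact - naive)
print(f"Frobenius gap between exact and naive: {gap:.3f}")
```

The gap vanishes only in special cases, for example when every client holds identical factors.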

**FedAvg produces mathematically inexact aggregation for LoRA.** FedEx-LoRA fixes this:

1. Computes the exact weighted average of the dense products `B_i @ A_i`
2. Uses SVD to get the optimal rank-r approximation (Eckart-Young theorem)
3. Tracks the residual between exact and approximate results
4. Folds residuals into base weights, making the combined result **exact**
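
The four steps can be sketched for a single layer with NumPy. This follows the steps as listed in this README and is not the code in `aggregation.py`. The function name `fedex_lora_round`, the shapes, and the use of dataset sizes as weights are assumptions made for illustration.

```python
import numpy as np


def fedex_lora_round(Bs, As, weights, rank):
    """Return (B_new, A_new, residual) for one layer, following the four steps."""
    w = np.asarray(weights, dtype=float)
    w = w / w.sum()

    # Step 1: exact weighted average of the dense products B_i @ A_i.
    exact = sum(wi * (B @ A) for wi, B, A in zip(w, Bs, As))

    # Step 2: best rank-r approximation via truncated SVD (Eckart-Young).
    U, S, Vt = np.linalg.svd(exact, full_matrices=False)
    B_new = U[:, :rank] * S[:rank]  # scale each kept column by its singular value
    A_new = Vt[:rank, :]

    # Step 3: residual between the exact average and its rank-r approximation.
    residual = exact - B_new @ A_new
    return B_new, A_new, residual


rng = np.random.default_rng(0)
Bs = [rng.normal(size=(8, 2)) for _ in range(3)]
As = [rng.normal(size=(2, 6)) for _ in range(3)]
sizes = [5000, 1000, 2500]  # hypothetical per-client dataset sizes

B_new, A_new, residual = fedex_lora_round(Bs, As, sizes, rank=2)

# Step 4: fold the residual into the base weight so nothing is lost.
W_base = rng.normal(size=(8, 6))
W_folded = W_base + residual

# Check: folded base + new adapter equals original base + exact average.
exact = sum(s / sum(sizes) * (B @ A) for s, B, A in zip(sizes, Bs, As))
error = np.linalg.norm((W_folded + B_new @ A_new) - (W_base + exact))
```

The residual is nonzero here because the average of three rank-2 products generally has rank above 2, which is exactly the part a rank-2 adapter cannot carry.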

## CLI Reference

### `chorus server`

Start the aggregation server.

```bash
chorus server --model <model-id> [options]
```

| Option | Default | Description |
|--------|---------|-------------|
| `--model` | *required* | Model ID (e.g. `meta-llama/Llama-3.2-3B`) |
| `--port` | `8080` | Port to listen on |
| `--host` | `0.0.0.0` | Host to bind to |
| `--data-dir` | `./chorus_data` | Data directory for storage |
| `--strategy` | `fedex-lora` | Aggregation strategy (`fedavg` or `fedex-lora`) |
| `--min-deltas` | `2` | Minimum deltas before aggregation triggers |
| `--dp-epsilon` | *disabled* | Server-side differential privacy epsilon |
| `--api-key` | *disabled* | API key for auth (can specify multiple times) |
| `--base-weights` | *none* | Path to base model weights (`.safetensors`) |
| `--norm-bound` | *disabled* | Max L2 norm for Byzantine defense |
| `--outlier-threshold` | *disabled* | Z-score threshold for outlier detection |
| `--rate-limit` | `0` | Max requests per minute per IP (0 = disabled) |
| `-v, --verbose` | | Verbose logging |

### `chorus submit`

Submit a LoRA adapter delta to the server.

```bash
chorus submit --server <url> --adapter <path> [options]
```

| Option | Default | Description |
|--------|---------|-------------|
| `--server` | *required* | Server URL |
| `--adapter` | *required* | Path to adapter directory or `.safetensors` file |
| `--model-id` | *auto* | Model ID (auto-detected from server) |
| `--client-id` | *auto* | Client identifier |
| `--round-id` | *current* | Target round |
| `--dp-epsilon` | *disabled* | Local DP epsilon |
| `--dataset-size` | *none* | Dataset size for weighted aggregation |
| `--api-key` | *none* | API key for authentication |

### `chorus pull`

Pull the latest aggregated adapter from the server.

```bash
chorus pull --server <url> --output <path> [options]
```

### `chorus train`

Run the full federated training loop (train -> submit -> wait -> pull -> repeat).

```bash
chorus train --server <url> --model <hf-model-id> --dataset <dataset> [options]
```

| Option | Default | Description |
|--------|---------|-------------|
| `--server` | *required* | Server URL |
| `--model` | *required* | HuggingFace model ID |
| `--dataset` | *required* | HuggingFace dataset or local path |
| `--rounds` | *infinite* | Number of training rounds |
| `--lora-rank` | `16` | LoRA rank |
| `--max-steps` | `-1` | Max training steps per round (-1 = full epoch) |
| `--dp-epsilon` | *disabled* | Local DP epsilon |

### `chorus simulate`

Run a simulated federation with synthetic data.

```bash
chorus simulate --clients 10 --rounds 5 --compare
```

### `chorus status`

Show the current status of a Chorus server.

```bash
chorus status --server <url>
```

### `chorus export`

Export a merged model (base + aggregated adapter) ready for deployment.

```bash
chorus export --server <url> --model <hf-model-id> --output ./merged/
```
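
For intuition, merging an adapter into a base weight is usually a single matrix update per layer. The sketch below shows the conventional LoRA merge `W + (alpha / r) * B @ A` with NumPy; it is an assumption about what "merged" means here, not a description of how `chorus export` is implemented, and `merge_lora` is a hypothetical helper.

```python
import numpy as np


def merge_lora(W_base, B, A, lora_alpha, rank):
    """Fold a LoRA adapter into a dense weight: W + (alpha / r) * B @ A."""
    scaling = lora_alpha / rank
    return W_base + scaling * (B @ A)


rng = np.random.default_rng(0)
W_base = rng.normal(size=(8, 6))
B, A = rng.normal(size=(8, 2)), rng.normal(size=(2, 6))
W_merged = merge_lora(W_base, B, A, lora_alpha=4, rank=2)

# The merged layer gives the same output as the base layer plus the adapter branch.
x = rng.normal(size=6)
adapter_out = W_base @ x + (4 / 2) * (B @ (A @ x))
merged_out = W_merged @ x
```

After merging, inference no longer needs the separate adapter matrices.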

## Python SDK

### `ChorusClient`

```python
from chorus import ChorusClient

client = ChorusClient(
    server="http://localhost:8080",
    model_id="my-model",
    client_id="client-1",          # optional, auto-generated if omitted
    api_key="secret",              # optional, for authenticated servers
    dp_epsilon=1.0,                # optional, local differential privacy
    dp_delta=1e-5,                 # optional, DP delta parameter
    dp_max_norm=1.0,               # optional, DP clipping norm
    timeout=120.0,                 # optional, HTTP timeout in seconds
)

# Check server status
status = client.status()

# Submit a trained LoRA adapter
result = client.submit_delta(
    adapter_path="./my-adapter",   # PEFT adapter dir or .safetensors
    round_id=None,                 # None = current round
    dataset_size=5000,             # for weighted aggregation
)

# Submit raw tensors directly
result = client.submit_tensors(tensors={"layer.lora_A.weight": tensor_a})  # add the remaining LoRA tensors

# Pull the latest aggregated adapter
client.pull_latest(output_path="./updated-adapter")

# Pull a specific round
client.pull_round(round_id=3, output_path="./round-3-adapter")

# Export merged model (requires chorus-fl[peft])
client.export_model(
    base_model="meta-llama/Llama-3.2-3B",
    output_dir="./merged-model",
)

# Full training loop (requires chorus-fl[peft])
client.train_loop(
    trainer=my_trainer,            # LoRATrainer instance
    rounds=5,
)

# Listen for round completion via WebSocket
for event in client.listen():
    print(f"Round {event['round_id']} complete!")

client.close()
# Or use as context manager:
# with ChorusClient(...) as client:
#     ...
```

## API Endpoints

| Method | Path | Description |
|--------|------|-------------|
| `GET` | `/health` | Health check (public, includes `ws_clients` count) |
| `GET` | `/models/{id}/status` | Round state, delta count, latest round |
| `POST` | `/rounds/{round_id}/deltas` | Submit LoRA delta (`dataset_size` param for weighting) |
| `GET` | `/models/{id}/latest` | Download latest aggregated adapter |
| `GET` | `/models/{id}/rounds/{round_id}` | Download round-specific adapter |
| `POST` | `/models/{id}/base-weights` | Upload base model weights |
| `GET` | `/models/{id}/base-weights` | Download current base weights |
| `GET` | `/models/{id}/checkpoint` | Download base + adapter merged checkpoint |
| `WS` | `/ws/{client_id}` | WebSocket for live round notifications |
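
As a rough sketch, the status endpoint can be called with nothing but the standard library. The path and the Bearer scheme come from this README. The helper names, the percent-encoding of the model ID, and the assumption that the response body is JSON are illustrative guesses and may not match the server exactly.

```python
import json
import urllib.request
from typing import Optional
from urllib.parse import quote


def status_request(server: str, model_id: str, api_key: Optional[str] = None):
    """Build a GET request for /models/{id}/status (not sent yet)."""
    # Assumption: the model ID is percent-encoded as a single path segment.
    url = f"{server.rstrip('/')}/models/{quote(model_id, safe='')}/status"
    headers = {"Accept": "application/json"}
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"
    return urllib.request.Request(url, headers=headers, method="GET")


def fetch_json(request: urllib.request.Request, timeout: float = 10.0):
    """Send the request and decode the body, assuming it is JSON."""
    with urllib.request.urlopen(request, timeout=timeout) as resp:
        return json.loads(resp.read().decode("utf-8"))


req = status_request("http://localhost:8080", "my-model", api_key="secret")
# fetch_json(req) would perform the call against a running server.
```

In practice `ChorusClient.status()` covers this call; the raw request is mainly useful for debugging or for clients written in other languages.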

## Architecture

```
chorus/
├── patterns.py              # Shared LoRA key patterns
├── exceptions.py            # Exception hierarchy (ChorusError, etc.)
├── server/
│   ├── app.py               # FastAPI endpoints + auth + async aggregation
│   ├── aggregation.py       # FedAvg + FedEx-LoRA (SVD) + Byzantine defenses
│   ├── storage.py           # Filesystem storage for deltas, base weights, round state
│   ├── weight_manager.py    # Residual folding into base weights
│   ├── ws.py                # WebSocket connection manager
│   └── privacy.py           # Gaussian DP mechanism + L2 clipping
├── client/
│   ├── sdk.py               # ChorusClient (submit, pull, listen, train_loop, export)
│   ├── trainer.py           # LoRATrainer wrapper for HF PEFT
│   └── delta.py             # LoRA matrix extraction from PEFT adapters
├── cli/
│   └── main.py              # Click CLI with error handling
└── simulate/
    └── runner.py            # Synthetic multi-client federation runner
```

## Security Features

Chorus includes several security mechanisms for production deployments:

- **Authentication** — Bearer token auth via `--api-key` (supports multiple keys)
- **Differential privacy** — Gaussian DP with global L2 clipping at both client and server level
- **Byzantine defenses** — L2 norm bounding (`--norm-bound`) and z-score outlier detection (`--outlier-threshold`) reject malicious or corrupted deltas
- **Rate limiting** — Per-IP request throttling via `--rate-limit`
- **safetensors only** — All weight serialization uses safetensors format (no pickle deserialization)
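
To make the Byzantine defenses concrete, here is a minimal sketch of norm bounding plus z-score filtering. It assumes both checks operate on the L2 norm of each flattened delta, which this README does not specify; `filter_deltas` and its parameters are hypothetical and only mirror the `--norm-bound` and `--outlier-threshold` flags.

```python
import numpy as np


def filter_deltas(deltas, norm_bound=None, outlier_threshold=None):
    """Return indices of deltas that pass both checks (hypothetical helper)."""
    norms = np.array([np.linalg.norm(d) for d in deltas])
    keep = np.ones(len(deltas), dtype=bool)

    # Reject any delta whose L2 norm exceeds the bound.
    if norm_bound is not None:
        keep &= norms <= norm_bound

    # Reject deltas whose norm is a z-score outlier within this batch.
    if outlier_threshold is not None and norms.std() > 0:
        z = np.abs(norms - norms.mean()) / norms.std()
        keep &= z <= outlier_threshold

    return [i for i, ok in enumerate(keep) if ok]


rng = np.random.default_rng(0)
honest = [rng.normal(scale=0.1, size=64) for _ in range(9)]
poisoned = [np.full(64, 50.0)]  # one client sends a huge update
accepted = filter_deltas(honest + poisoned, norm_bound=10.0, outlier_threshold=2.5)
```

Note that with a population standard deviation, a single outlier among `n` values can reach a z-score of at most `sqrt(n - 1)`, so a threshold of 3.0 cannot fire with ten or fewer deltas in a round.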

> **Note:** Chorus serves over HTTP. For production, deploy behind a TLS-terminating reverse proxy (nginx, Caddy, etc.).

## Aggregation Strategies

| Strategy | Exact? | Description |
|----------|--------|-------------|
| `fedex-lora` (default) | Yes | SVD-based exact aggregation with residual folding |
| `fedavg` | No | Naive independent averaging of A and B matrices |

## Configuration Examples

### Secure production server

```bash
chorus server \
  --model meta-llama/Llama-3.2-3B \
  --min-deltas 5 \
  --api-key $SECRET_KEY_1 \
  --api-key $SECRET_KEY_2 \
  --dp-epsilon 2.0 \
  --norm-bound 10.0 \
  --outlier-threshold 3.0 \
  --rate-limit 60 \
  --base-weights ./base-model.safetensors
```

### Client with local DP

```python
client = ChorusClient(
    server="http://chorus.internal:8080",
    model_id="meta-llama/Llama-3.2-3B",
    api_key="my-secret-key",
    dp_epsilon=1.0,        # strong local DP
    dp_max_norm=1.0,       # clip before noising
)
```
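
The "clip before noising" step can be sketched as follows. This uses the classic Gaussian mechanism calibration `sigma = C * sqrt(2 ln(1.25 / delta)) / epsilon`, whose analysis assumes `epsilon` below 1, so the example uses `epsilon=0.5`. Whether Chorus calibrates noise this way is an assumption; `gaussian_sigma` and `privatize` are hypothetical helpers, and the tensor names are placeholders.

```python
import math

import numpy as np


def gaussian_sigma(epsilon, delta, max_norm):
    """Noise scale of the classic Gaussian mechanism for L2 sensitivity max_norm."""
    return max_norm * math.sqrt(2.0 * math.log(1.25 / delta)) / epsilon


def privatize(tensors, epsilon, delta, max_norm, rng):
    """Clip the global L2 norm across all tensors, then add Gaussian noise."""
    global_norm = math.sqrt(sum(float(np.sum(t ** 2)) for t in tensors.values()))
    scale = min(1.0, max_norm / (global_norm + 1e-12))
    sigma = gaussian_sigma(epsilon, delta, max_norm)
    return {
        name: t * scale + rng.normal(scale=sigma, size=t.shape)
        for name, t in tensors.items()
    }


rng = np.random.default_rng(0)
delta_tensors = {
    "layer.lora_A.weight": rng.normal(size=(2, 6)),
    "layer.lora_B.weight": rng.normal(size=(8, 2)),
}
noisy = privatize(delta_tensors, epsilon=0.5, delta=1e-5, max_norm=1.0, rng=rng)
sigma = gaussian_sigma(0.5, 1e-5, 1.0)
```

Clipping is global: a single scale factor is applied to every tensor so that the concatenated update has norm at most `max_norm`, which keeps the sensitivity bound independent of how many layers the adapter has.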

### Full training loop

```bash
chorus train \
  --server http://localhost:8080 \
  --model meta-llama/Llama-3.2-3B \
  --dataset wikitext \
  --rounds 10 \
  --lora-rank 16
```

## Examples

See the [`examples/`](examples/) directory:

- **[`quickstart.py`](examples/quickstart.py)** — Basic 2-client workflow with synthetic adapters
- **[`health_metrics/federated_health.py`](examples/health_metrics/federated_health.py)** — Multi-hospital federated training simulation with DP

## Development

```bash
git clone https://github.com/varmabudharaju/chorus.git
cd chorus
pip install -e ".[dev]"

# Run tests (165 tests)
pytest tests/ -v

# Run benchmarks
python benchmarks/benchmark.py
```

## Contributing

Contributions are welcome! Here's how to get started:

1. Fork the repository
2. Create a feature branch (`git checkout -b feature/my-feature`)
3. Make your changes and add tests
4. Run the test suite (`pytest tests/ -v`)
5. Submit a pull request

Please open an issue first to discuss significant changes.

## License

Apache 2.0 — see [LICENSE](LICENSE) for the full text.
