Metadata-Version: 2.4
Name: torchlitex
Version: 0.1.9
Summary: Tiny DDP training toolkit for quick-launch distributed training loops.
Author: Torchlitex Authors
License: MIT
Requires-Python: >=3.9
Description-Content-Type: text/markdown
Requires-Dist: torch>=2.0
Provides-Extra: dev
Requires-Dist: pytest>=7.0; extra == "dev"
Provides-Extra: wandb
Requires-Dist: wandb>=0.16; extra == "wandb"

# torchlitex

Tiny DDP launcher + trainer that exists because PyTorch 2.x still thinks we enjoy 400 lines of torchrun boilerplate and cryptic NCCL errors. This trims it to ~20 lines and keeps `fork` happy on Vast.

## Why not just torchrun?
- You like your code more than the 17 environment variables PyTorch asks you to memorize.
- `torchrun` still feels like a 2010 MPI cosplay.
- You want process startup (fork or spawn) that doesn’t randomly faceplant on Vast.
- You want a one-function launcher + trainer, not a CLI maze.

## Install
```bash
pip install -e .
# or with wandb logging (quotes stop zsh from globbing the brackets)
pip install -e ".[wandb]"
```

## Quick start
```python
from torchlitex.launcher import launch, DistributedConfig
from torchlitex.trainer import Trainer
from torch import nn, optim

def train_fn(rank, world_size, batch_size, epochs):
    dataset = MyDataset(...)  # your torch.utils.data.Dataset
    model = MyModel(...)      # your nn.Module
    loss_fn = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=3e-4)

    trainer = Trainer(
        model=model,
        dataset=dataset,
        loss_fn=loss_fn,
        optimizer=optimizer,
        grad_clip_norm=1.0,
        log_every=10,
    )
    trainer.ddp_train_loop(rank, world_size, batch_size=batch_size, epochs=epochs, ckpt_path="ckpt.pt")

if __name__ == "__main__":
    cfg = DistributedConfig(gpus=8)  # backend auto-switches nccl/gloo
    launch(train_fn, cfg, batch_size=64, epochs=20)
```

## Features
- Fork-first DDP launcher (no torchrun, no elastic).
- Auto backend: `nccl` when CUDA exists, `gloo` when you're on a laptop/CI.
- Trainer: AMP toggle, grad clipping, gradient accumulation, microbatching, optional schedulers, eval hook, callbacks, and EMA support.
- **Rich callback system** with access to model outputs, inputs, and targets for custom progress bars and visualizations.
- **Built-in validation loop** with per-epoch val loss and sample outputs for reconstruction visualization.
- Optional wandb init + logging on rank 0 (install with `.[wandb]`).
- DistributedSampler + DataLoader defaults that just work.
- Start method auto-picks `spawn` when CUDA is present (avoids CUDA init errors in forked child processes); override to `fork` if you really want it.
- Checkpoint utilities that handle optimizer/scaler safely.
- Rank-aware logging that doesn't spam.
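The backend and start-method bullets above boil down to two small selection rules. Here is a minimal sketch of them; `pick_backend` and `pick_start_method` are hypothetical names, not torchlitex's API, and the `fork` default on CPU-only machines is an assumption based on the "fork-first" wording. In real code you would pass `torch.cuda.is_available()` as the flag.

```python
from typing import Optional


def pick_backend(cuda_available: bool) -> str:
    # NCCL needs CUDA devices; Gloo runs on plain CPUs (laptops, CI runners).
    return "nccl" if cuda_available else "gloo"


def pick_start_method(cuda_available: bool, override: Optional[str] = None) -> str:
    # An explicit choice (e.g. "fork") always wins.
    if override is not None:
        if override not in ("fork", "spawn", "forkserver"):
            raise ValueError(f"unknown start method: {override!r}")
        return override
    # With CUDA, "spawn" sidesteps "Cannot re-initialize CUDA in forked subprocess".
    # Without CUDA, this sketch ASSUMES fork (cheaper startup) is the default.
    return "spawn" if cuda_available else "fork"
```

Note that `fork` is not available on Windows, so a CPU-only Windows box would need the `spawn` override.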

## Callbacks

The trainer fires rich callbacks with access to model outputs, inputs, and targets. This enables custom progress bars, reconstruction visualization, and more.

### Callback hooks

| Hook | Parameters |
|------|------------|
| `on_train_start` | `trainer`, `rank` |
| `on_epoch_start` | `epoch`, `num_batches`, `trainer`, `rank` |
| `on_batch_start` | `epoch`, `batch_idx`, `num_batches`, `trainer`, `rank` |
| `on_step_end` | `epoch`, `step`, `batch_idx`, `num_batches`, `trainer`, `rank`, `loss`, `logits`, `inputs`, `targets` |
| `on_batch_end` | `epoch`, `batch_idx`, `num_batches`, `trainer`, `rank`, `loss` |
| `on_val_end` | `epoch`, `val_loss`, `trainer`, `rank`, `sample_logits`, `sample_inputs`, `sample_targets` |
| `on_epoch_end` | `epoch`, `loss`, `val_loss`, `best`, `trainer`, `rank` |
| `on_train_end` | `trainer`, `rank` |

### Example: tqdm progress bar
```python
from tqdm import tqdm

class TQDMCallback:
    def __init__(self):
        self.pbar = None

    def on_epoch_start(self, epoch, num_batches, rank, **_):
        if rank == 0:
            self.pbar = tqdm(total=num_batches, desc=f"Epoch {epoch}")

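    def on_batch_end(self, rank, **_):
        # `is not None` rather than truthiness: a tqdm bar with total=0 is falsy
        if rank == 0 and self.pbar is not None:
            self.pbar.update(1)

    def on_epoch_end(self, rank, **_):
        if rank == 0 and self.pbar is not None:
            self.pbar.close()
            self.pbar = None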
```

### Example: reconstruction visualization (autoencoders)
```python
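import matplotlib
matplotlib.use("Agg")  # headless backend: training nodes usually have no display
import matplotlib.pyplot as plt

class ReconCallback:
    def on_val_end(self, epoch, sample_inputs, sample_logits, rank, **_):
        if rank != 0 or sample_inputs is None:
            return
        # For an autoencoder: inputs are the originals, logits are the reconstructions.
        # Assumes CHW image tensors with values roughly in [0, 1].
        orig = sample_inputs[0].detach().float().cpu().clamp(0, 1)  # first sample
        recon = sample_logits[0].detach().float().cpu().clamp(0, 1)
        fig, axes = plt.subplots(1, 2)
        # CHW -> HWC for imshow; squeeze(-1) drops a single grayscale channel
        axes[0].imshow(orig.permute(1, 2, 0).squeeze(-1).numpy()); axes[0].set_title("Original")
        axes[1].imshow(recon.permute(1, 2, 0).squeeze(-1).numpy()); axes[1].set_title("Reconstructed")
        fig.savefig(f"recon_epoch_{epoch}.png")
        plt.close(fig)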
```

### Example: validation dataset
```python
trainer = Trainer(
    model=model,
    dataset=train_dataset,
    val_dataset=val_dataset,  # built-in val loop
    loss_fn=loss_fn,
    optimizer=optimizer,
    callbacks=[TQDMCallback(), ReconCallback()],
)
```

## Testing levels
- Level 1: CPU unit tests (no dist).
- Level 2: CPU DDP (`backend="gloo"`, `world_size=2`) to validate spawn, env setup, and the sampler.
- Level 3: Single GPU (`gpus=1`) to exercise the full DDP path end to end.
- Level 4: Real multi-GPU (same code, just crank `gpus`).
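For the sampler check in Level 2, it helps to know which indices each rank should see. The sketch below is a pure-Python mirror of how PyTorch's `DistributedSampler` splits indices when `shuffle=False`: pad by wrapping around (or trim with `drop_last=True`) to a multiple of the world size, then take a strided slice per rank. `shard_indices` is a hypothetical helper for writing test expectations, not part of torchlitex.

```python
import math
from typing import List


def shard_indices(n: int, num_replicas: int, rank: int, drop_last: bool = False) -> List[int]:
    """Indices rank `rank` receives from a non-shuffled DistributedSampler."""
    indices = list(range(n))
    if drop_last:
        # Trim the tail so every rank gets the same number of samples.
        num_samples = n // num_replicas
        total = num_samples * num_replicas
        indices = indices[:total]
    else:
        # Pad by wrapping around to the start until the total divides evenly.
        num_samples = math.ceil(n / num_replicas)
        total = num_samples * num_replicas
        pad = total - n
        indices += (indices * math.ceil(pad / n))[:pad] if pad else []
    # Strided slice: rank, rank + num_replicas, rank + 2 * num_replicas, ...
    return indices[rank:total:num_replicas]
```

A Level 2 test can compare `list(DistributedSampler(ds, num_replicas=2, rank=r, shuffle=False))` against this without calling `init_process_group`, since passing `num_replicas` and `rank` explicitly skips the process-group lookup. The wrap-around padding also means a few samples are duplicated across ranks, which slightly skews validation metrics on small datasets.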

## PyTorch 2.x roast (lightly toasted)
- DDP config still feels like “choose your own adventure” but every page ends with NCCL complaining.
- torch.distributed docs read like a treasure map; the treasure is another flag.
- “Just use torchrun” is 2020’s “have you tried turning it off and on again?”

torchlitex keeps the good bits of torch 2.x (SDPA, compile, etc.) and sidesteps the distributed busywork. Use it, ship models, spend less time appeasing the NCCL spirits.
