Metadata-Version: 2.4
Name: torchbridge-ml
Version: 0.5.45
Summary: Cross-backend validation and configuration intelligence for PyTorch — NVIDIA, AMD, Trainium, and TPU
Author: TorchBridge Team
License: MIT
Project-URL: Homepage, https://github.com/CloudlyIO/torchbridge
Project-URL: Documentation, https://torchbridge.readthedocs.io
Project-URL: Repository, https://github.com/CloudlyIO/torchbridge
Project-URL: Bug Tracker, https://github.com/CloudlyIO/torchbridge/issues
Keywords: pytorch,hardware-abstraction,multi-backend,cuda,amd,trainium,tpu,machine-learning,deep-learning
Classifier: Development Status :: 5 - Production/Stable
Classifier: Intended Audience :: Developers
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: MIT License
Classifier: Operating System :: OS Independent
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Programming Language :: Python :: 3.13
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Classifier: Topic :: Software Development :: Libraries :: Python Modules
Requires-Python: >=3.10
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: torch<3.0.0,>=2.0.0
Requires-Dist: numpy<3.0.0,>=1.21.0
Provides-Extra: dev
Requires-Dist: pytest<9.0.0,>=7.0.0; extra == "dev"
Requires-Dist: pytest-cov<6.0.0,>=4.0.0; extra == "dev"
Requires-Dist: pytest-xdist<4.0.0,>=3.0.0; extra == "dev"
Requires-Dist: pytest-benchmark<5.0.0,>=4.0.0; extra == "dev"
Requires-Dist: hypothesis<7.0.0,>=6.0.0; extra == "dev"
Requires-Dist: matplotlib<4.0.0,>=3.5.0; extra == "dev"
Requires-Dist: seaborn<1.0.0,>=0.11.0; extra == "dev"
Requires-Dist: jupyter<2.0.0,>=1.0.0; extra == "dev"
Requires-Dist: tensorboard<3.0.0,>=2.9.0; extra == "dev"
Requires-Dist: ruff<1.0.0,>=0.4.0; extra == "dev"
Requires-Dist: mypy<2.0.0,>=1.0.0; extra == "dev"
Requires-Dist: pre-commit<4.0.0,>=3.5.0; extra == "dev"
Requires-Dist: bandit[toml]<2.0.0,>=1.7.0; extra == "dev"
Provides-Extra: all
Requires-Dist: transformers<5.0.0,>=4.35.0; extra == "all"
Requires-Dist: datasets<3.0.0,>=2.14.0; extra == "all"
Requires-Dist: tokenizers<1.0.0,>=0.14.0; extra == "all"
Requires-Dist: triton<4.0.0,>=2.1.0; extra == "all"
Requires-Dist: flash-attn<4.0.0,>=2.3.0; extra == "all"
Requires-Dist: accelerate<1.0.0,>=0.24.0; extra == "all"
Provides-Extra: cloud
Requires-Dist: boto3<2.0.0,>=1.28.0; extra == "cloud"
Requires-Dist: google-cloud-storage<3.0.0,>=2.10.0; extra == "cloud"
Requires-Dist: azure-storage-blob<13.0.0,>=12.17.0; extra == "cloud"
Requires-Dist: kubernetes<32.0.0,>=27.2.0; extra == "cloud"
Provides-Extra: serving
Requires-Dist: fastapi<1.0.0,>=0.103.0; extra == "serving"
Requires-Dist: uvicorn[standard]<1.0.0,>=0.23.0; extra == "serving"
Requires-Dist: torchserve<1.0.0,>=0.8.0; extra == "serving"
Requires-Dist: gradio<6.0.0,>=3.45.0; extra == "serving"
Requires-Dist: streamlit<2.0.0,>=1.27.0; extra == "serving"
Provides-Extra: monitoring
Requires-Dist: prometheus-client<1.0.0,>=0.17.0; extra == "monitoring"
Requires-Dist: wandb<1.0.0,>=0.16.0; extra == "monitoring"
Requires-Dist: tensorboard<3.0.0,>=2.14.0; extra == "monitoring"
Requires-Dist: mlflow<3.0.0,>=2.7.0; extra == "monitoring"
Requires-Dist: optuna<4.0.0,>=3.4.0; extra == "monitoring"
Provides-Extra: benchmark
Requires-Dist: memory-profiler<1.0.0,>=0.61.0; extra == "benchmark"
Requires-Dist: py-spy<1.0.0,>=0.3.14; extra == "benchmark"
Requires-Dist: torch-tb-profiler<1.0.0,>=0.4.0; extra == "benchmark"
Requires-Dist: psutil<6.0.0,>=5.9.0; extra == "benchmark"
Requires-Dist: gpustat<2.0.0,>=1.1.0; extra == "benchmark"
Provides-Extra: tracing
Requires-Dist: opentelemetry-api<2.0.0,>=1.20.0; extra == "tracing"
Provides-Extra: quantization
Requires-Dist: torchao<1.0.0,>=0.4.0; extra == "quantization"
Provides-Extra: checkpoint
Requires-Dist: s3fs<2026.0.0,>=2024.1.0; extra == "checkpoint"
Requires-Dist: gcsfs<2026.0.0,>=2024.1.0; extra == "checkpoint"
Requires-Dist: adlfs<2026.0.0,>=2024.1.0; extra == "checkpoint"
Provides-Extra: docs
Requires-Dist: sphinx<8.0.0,>=7.0.0; extra == "docs"
Requires-Dist: furo<2026.0.0,>=2024.1.0; extra == "docs"
Requires-Dist: myst-parser<4.0.0,>=2.0.0; extra == "docs"
Requires-Dist: sphinx-autobuild<2026.0.0,>=2024.1.0; extra == "docs"
Dynamic: license-file

# TorchBridge

**Your PyTorch code is locked to one GPU vendor.** CUDA calls, NCCL hardcoding, vendor-specific precision tricks -- they break the moment you switch hardware. TorchBridge is a cross-backend validation and configuration intelligence layer for PyTorch: it **validates that outputs match across backends** and generates optimal configurations for NVIDIA, AMD, Trainium, and TPU hardware.

[![Version](https://img.shields.io/pypi/v/torchbridge-ml?label=version&color=green)](./CHANGELOG.md) [![Tests](https://img.shields.io/badge/tests-2%2C668%20passed-blue)](./docs/reference/hardware-matrix.md) [![Cloud GPU](https://img.shields.io/badge/platforms-8%20validated%2C%206%20GPU-brightgreen)](./docs/reference/cloud-validation.md) [![AWS A10G](https://img.shields.io/badge/AWS%20A10G-PASS-brightgreen)](./docs/reference/cloud-validation.md) [![GCP T4](https://img.shields.io/badge/GCP%20T4-PASS-brightgreen)](./docs/reference/cloud-validation.md) [![H100 NVL](https://img.shields.io/badge/H100%20NVL-PASS-brightgreen)](./docs/reference/cloud-validation.md) [![MI300X](https://img.shields.io/badge/MI300X-PASS-brightgreen)](./docs/reference/cloud-validation.md) [![TPU v5e](https://img.shields.io/badge/TPU%20v5e-PASS-brightgreen)](./docs/reference/cloud-validation.md) [![Python](https://img.shields.io/badge/python-3.10%2B-blue)](https://python.org) [![PyTorch](https://img.shields.io/badge/pytorch-2.0%2B-orange)](https://pytorch.org)

## What is TorchBridge?

PyTorch lets you build models. TorchBridge lets you run them **anywhere**.

Most teams write hardware-specific code -- CUDA calls for NVIDIA, ROCm setup for AMD, NeuronX setup for Trainium, XLA boilerplate for TPU. When the hardware changes, the code breaks. TorchBridge eliminates that problem with a **unified API** that detects your hardware and adapts automatically.

```
Your model code
      |
  TorchBridge
      |
  +---------+---------+-----------+---------+
  | NVIDIA  |   AMD   | Trainium  |   TPU   |
  | CUDA    |  ROCm   |  NeuronX  |   XLA   |
  +---------+---------+-----------+---------+
```

**What it does:**
- **Backend detection** -- automatically identifies available accelerators
- **Vendor adapters** -- translates unified API calls to vendor-specific operations
- **Precision management** -- handles FP32/FP16/BF16/FP8 across backends with compatibility matrices
- **Quantization** -- backend-aware format selection with automatic fallback chains
- **Attention dispatch** -- selects the best attention kernel (FlexAttention, Flash, Triton, etc.) per hardware
- **Checkpoint portability** -- save on one backend, load on another with dtype normalization
- **Distributed config** -- generates FSDP, pipeline, and collective configs from detected topology
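
Attention dispatch works like a priority list with a guaranteed fallback. Here is a minimal sketch of that idea. `HardwareInfo`, the kernel labels and the eligibility rules are hypothetical illustrations, not TorchBridge's actual dispatch table. The sketch takes two things as given: FlashAttention-2 targets NVIDIA Ampere (compute capability 8.0) or newer, and PyTorch's built-in `scaled_dot_product_attention` is the device-generic fallback.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class HardwareInfo:
    """Hypothetical summary of the detected device (illustrative fields only)."""
    vendor: str                           # "nvidia", "amd", "tpu", "trainium", "cpu"
    compute_capability: tuple = (0, 0)    # CUDA (major, minor); unused off NVIDIA
    torch_version: tuple = (2, 0)


def select_attention_kernel(hw: HardwareInfo) -> str:
    """Return the first kernel whose (assumed) requirements are met.

    The order and the rules are illustrative assumptions, not TorchBridge's real table.
    """
    candidates = [
        # FlexAttention shipped in PyTorch 2.5; this sketch only allows it on GPUs
        ("flex_attention", lambda h: h.vendor in {"nvidia", "amd"} and h.torch_version >= (2, 5)),
        # FlashAttention-2 targets NVIDIA Ampere (sm_80) and newer
        ("flash_attention", lambda h: h.vendor == "nvidia" and h.compute_capability >= (8, 0)),
    ]
    for name, is_eligible in candidates:
        if is_eligible(hw):
            return name
    # torch.nn.functional.scaled_dot_product_attention: PyTorch's generic fallback
    return "sdpa"


print(select_attention_kernel(HardwareInfo("nvidia", (8, 0), (2, 4))))  # flash_attention
print(select_attention_kernel(HardwareInfo("tpu")))                      # sdpa
```

Always ending in a generic kernel guarantees that every device gets *some* attention implementation, even when no specialized kernel qualifies.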

## Quick Start

```bash
pip install torchbridge-ml

# Verify
python3 -c "import torchbridge; print(f'TorchBridge v{torchbridge.__version__} ready')"
```

For development:

```bash
git clone https://github.com/CloudlyIO/torchbridge.git
cd torchbridge
pip install -e ".[dev]"
```

### Detect Hardware

```python
from torchbridge.backends import BackendFactory, detect_best_backend

backend_type = detect_best_backend()  # NVIDIA, AMD, Trainium, TPU, or CPU
backend = BackendFactory.create(backend_type)
print(backend.get_device_info())
```
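
For intuition, here is the kind of probing a detector such as `detect_best_backend()` might do with public PyTorch calls and package checks. `detect_backend_sketch` is a hypothetical helper, not TorchBridge's implementation. The presence of the `torch_neuronx` or `torch_xla` packages is only a heuristic for Trainium or TPU hardware.

```python
import importlib.util


def detect_backend_sketch() -> str:
    """Hypothetical detector returning "nvidia", "amd", "trainium", "tpu" or "cpu"."""
    if importlib.util.find_spec("torch") is None:
        return "cpu"  # PyTorch itself is missing; nothing else to probe
    import torch

    if torch.cuda.is_available():
        # ROCm builds of PyTorch expose AMD GPUs through torch.cuda and set torch.version.hip
        return "amd" if getattr(torch.version, "hip", None) else "nvidia"
    # Heuristic only: an installed plugin package suggests, but does not prove, the hardware
    if importlib.util.find_spec("torch_neuronx") is not None:
        return "trainium"
    if importlib.util.find_spec("torch_xla") is not None:
        return "tpu"
    return "cpu"


print(detect_backend_sketch())
```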

### Run on Any Backend

```python
import torch
from torchbridge import TorchBridgeConfig, UnifiedManager

config = TorchBridgeConfig.for_training()
manager = UnifiedManager(config)

model = torch.nn.Sequential(
    torch.nn.Linear(768, 3072),
    torch.nn.GELU(),
    torch.nn.Linear(3072, 768),
)

optimized_model = manager.optimize(model)
```

### Validate

```python
from torchbridge import UnifiedValidator

validator = UnifiedValidator()
results = validator.validate_model(optimized_model, input_shape=(1, 768))
print(f"Validation: {results.passed}/{results.total_tests} tests passed")
```

## Supported Backends

| Backend | Hardware | Precision | Status |
|---------|----------|-----------|--------|
| **NVIDIA** | B200, H100, H200, A100, L4, T4 | FP4, FP8, BF16, FP16, FP32 | Production |
| **AMD** | MI350X, MI325X, MI300X, MI200 | FP8, BF16, FP16, FP32 | Production |
| **Trainium** | Trn1, Trn2, Trn3 (AWS NeuronX) | BF16, FP16, FP32 | Supported |
| **TPU** | v4, v5e, v5p, v6e, v7 | BF16, FP32 | Production |
| **CPU** | x86, ARM (Apple Silicon) | FP32, BF16 | Fallback |

See [Hardware Matrix](./docs/reference/hardware-matrix.md) for full details.

## Key Features

### Backend Detection and Adaptation
Automatically identifies available hardware and selects the optimal backend. No code changes needed when moving between GPU vendors or cloud providers.

### Vendor Adapters
Each backend implements a common `BaseBackend` interface. Your code calls `manager.optimize(model)` and the correct vendor-specific operations execute underneath -- CUDA on NVIDIA, HIP on AMD, NeuronX on Trainium, XLA on TPU.
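
The adapter pattern behind a shared backend interface can be sketched in plain Python. `MiniBackend`, the adapter classes, `prepare` and `preferred_dtype` are hypothetical names chosen for illustration. They are not the methods of TorchBridge's `BaseBackend`.

```python
from abc import ABC, abstractmethod


class MiniBackend(ABC):
    """Hypothetical stand-in for a common backend interface."""

    name: str = "base"

    @abstractmethod
    def preferred_dtype(self) -> str:
        """Return the default compute dtype for this hardware."""

    def prepare(self, model_name: str) -> str:
        # Shared logic lives in the base class; vendor details come from the subclass
        return f"{model_name} on {self.name} using {self.preferred_dtype()}"


class NvidiaAdapter(MiniBackend):
    name = "nvidia"

    def preferred_dtype(self) -> str:
        return "bfloat16"


class CPUAdapter(MiniBackend):
    name = "cpu"

    def preferred_dtype(self) -> str:
        return "float32"


ADAPTERS = {cls.name: cls for cls in (NvidiaAdapter, CPUAdapter)}


def get_adapter(backend_name: str) -> MiniBackend:
    # Unknown names fall back to CPU, mirroring the "Fallback" status of the CPU backend
    return ADAPTERS.get(backend_name, CPUAdapter)()


print(get_adapter("nvidia").prepare("mlp"))  # mlp on nvidia using bfloat16
```

Calling code depends only on the base-class interface. Adding a vendor therefore means adding one subclass and one registry entry.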

### Precision Management
Configure precision once. TorchBridge handles the details per backend -- FP8 on H100, BF16 where supported, FP16 as fallback. Mixed-precision training with `torch.amp` autocast works across all backends.
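
"FP8 on H100, BF16 where supported, FP16 as fallback" describes a fallback chain: walk a preference list and take the first dtype the device supports. Here is a minimal sketch, assuming a hypothetical per-device `supported` set. The example sets are illustrative, not TorchBridge's compatibility matrix.

```python
PREFERENCE = ("fp8", "bf16", "fp16", "fp32")


def pick_precision(supported: set[str], preference: tuple[str, ...] = PREFERENCE) -> str:
    """Return the first preferred dtype the device supports; FP32 is the last resort."""
    for dtype in preference:
        if dtype in supported:
            return dtype
    return "fp32"


# Illustrative capability sets (assumptions for the example)
print(pick_precision({"fp8", "bf16", "fp16", "fp32"}))  # fp8  (an H100-class GPU)
print(pick_precision({"fp16", "fp32"}))                 # fp16 (a T4, which lacks BF16 tensor cores)
```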

### Checkpoint Portability
Save a checkpoint on NVIDIA hardware, load it on AMD, Trainium, or TPU. TorchBridge handles device mapping, dtype normalization, and FP8-to-FP16 conversion via PyTorch Distributed Checkpoint (DCP).
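
Dtype normalization comes down to mapping each saved dtype to one the target can load. Below is a minimal sketch over dtype names, which follow PyTorch's `torch.float8_e4m3fn` and `torch.float8_e5m2`. `normalize_dtype` is a hypothetical helper, not TorchBridge's API. It relies on two numeric facts. First, both FP8 formats convert to FP16 exactly. Second, BF16 should widen to FP32 rather than narrow to FP16, because BF16 values can exceed FP16's maximum of 65504.

```python
# Lossless widening steps: FP8 -> FP16 -> FP32, and BF16 -> FP32
WIDENING = {
    "float8_e4m3fn": "float16",
    "float8_e5m2": "float16",
    "float16": "float32",
    "bfloat16": "float32",
}


def normalize_dtype(saved: str, target_supported: set[str]) -> str:
    """Follow lossless widening steps until the target device supports the dtype."""
    dtype = saved
    while dtype not in target_supported:
        if dtype not in WIDENING:
            raise ValueError(f"no lossless conversion for {saved!r} on this target")
        dtype = WIDENING[dtype]
    return dtype


tpu_like = {"bfloat16", "float32"}  # assumed capability set for the example
print(normalize_dtype("float8_e4m3fn", tpu_like))  # float32 (via float16)
print(normalize_dtype("bfloat16", tpu_like))       # bfloat16
```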

### Distributed Configuration
Generates FSDP sharding strategies, pipeline schedules, and collective backend configs based on detected cluster topology. TorchBridge produces config objects that you pass to PyTorch's native distributed primitives -- it does not implement distributed training itself.
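
A topology-to-config step can be sketched as a pure function. The heuristic below is an illustrative assumption, not TorchBridge's logic. It returns two things: the name of a member of PyTorch's `torch.distributed.fsdp.ShardingStrategy` enum (`NO_SHARD`, `FULL_SHARD`, `HYBRID_SHARD`), and a process-group backend name for `torch.distributed.init_process_group`. Two backend-name conventions apply. ROCm builds of PyTorch keep the `"nccl"` name (backed by RCCL). PyTorch/XLA, which also underlies the NeuronX stack, registers `"xla"`.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class DistributedPlan:
    sharding_strategy: str  # name of a torch.distributed.fsdp.ShardingStrategy member
    pg_backend: str         # backend name for torch.distributed.init_process_group


PG_BACKEND = {"nvidia": "nccl", "amd": "nccl", "tpu": "xla", "trainium": "xla", "cpu": "gloo"}


def plan_distributed(vendor: str, num_nodes: int, devices_per_node: int) -> DistributedPlan:
    """Illustrative heuristic mapping cluster topology to FSDP settings."""
    world_size = num_nodes * devices_per_node
    if world_size == 1:
        strategy = "NO_SHARD"       # a single device has nothing to shard across
    elif num_nodes == 1:
        strategy = "FULL_SHARD"     # fast intra-node links: shard params, grads, optimizer state
    else:
        strategy = "HYBRID_SHARD"   # shard within each node, replicate across nodes
    return DistributedPlan(strategy, PG_BACKEND.get(vendor, "gloo"))


print(plan_distributed("amd", num_nodes=2, devices_per_node=8))
```

On TPU and Trainium, PyTorch/XLA ships its own FSDP wrapper, so the strategy name is only indicative there.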

## Code Examples

### Backend-Agnostic Training

```python
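import torch
from torchbridge.backends import BackendFactory, detect_best_backend

backend = BackendFactory.create(detect_best_backend())
device = backend.device

# YourModel and train_loader stand in for your own model and data pipeline
model = YourModel().to(device)
criterion = torch.nn.CrossEntropyLoss()  # e.g. for classification; use your own loss
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# PyTorch native AMP (torch.amp.GradScaler needs PyTorch 2.3+). Loss scaling only
# matters for FP16 on CUDA/ROCm GPUs, so the scaler is a pass-through elsewhere.
scaler = torch.amp.GradScaler(device.type, enabled=(device.type == "cuda"))
for inputs, targets in train_loader:
    inputs, targets = inputs.to(device), targets.to(device)
    optimizer.zero_grad(set_to_none=True)
    with torch.amp.autocast(device.type):
        loss = criterion(model(inputs), targets)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()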
```

### Hardware Capability Queries

```python
from torchbridge.backends.nvidia import NVIDIABackend

nvidia = NVIDIABackend()
print(nvidia.get_device_info())  # GPU model, compute capability, memory
print(nvidia.supports_fp8())     # True on H100+
```
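
A check like `supports_fp8()` usually keys off the CUDA compute capability. The sketch below assumes the common rule that FP8 tensor-core math needs compute capability 8.9 (Ada, e.g. L4) or newer (Hopper 9.0, Blackwell 10.x). TorchBridge's own rule may differ. `fp8_capable` is a hypothetical helper, while `torch.cuda.get_device_capability()` is PyTorch's real API for reading the value.

```python
def fp8_capable(capability: tuple[int, int]) -> bool:
    """Assumed rule: FP8 tensor cores are available from compute capability 8.9 onward."""
    return tuple(capability) >= (8, 9)


try:
    import torch
except ImportError:
    torch = None  # PyTorch not installed; the rule can still be exercised directly

if torch is not None and torch.cuda.is_available() and getattr(torch.version, "hip", None) is None:
    # get_device_capability() returns a (major, minor) tuple for the current CUDA device
    print(fp8_capable(torch.cuda.get_device_capability()))

print(fp8_capable((9, 0)), fp8_capable((8, 0)))  # True False
```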

### Cross-Backend Model Export

```python
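import torch
from torchbridge.deployment import export_to_torchscript, export_to_onnx, export_to_safetensors

# `model` is any torch.nn.Module, e.g. the one built in the Quick Start
model.eval()  # export in inference mode
sample_input = torch.randn(1, 768)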

export_to_torchscript(model, output_path="model.pt", sample_input=sample_input)
export_to_onnx(model, output_path="model.onnx", sample_input=sample_input)
export_to_safetensors(model, output_path="model.safetensors")
```

## Project Structure

```
src/torchbridge/
├── backends/          # Vendor-specific backend implementations
│   ├── nvidia/        #   NVIDIA CUDA backend
│   ├── amd/           #   AMD ROCm backend
│   ├── trainium/      #   AWS Trainium/NeuronX backend
│   └── tpu/           #   Google TPU/XLA backend
├── hardware/          # Hardware detection and abstraction
├── precision/         # FP8 training and precision management
├── attention/         # Attention mechanisms (unified API)
├── advanced_memory/   # Memory optimization strategies
├── distributed_scale/ # Distributed training
├── deployment/        # Model export and serving
├── monitoring/        # Metrics, logging, health checks
├── optimizations/     # Optimization patterns and strategies
├── core/              # Core config, management, optimized layers
├── cli/               # Command-line tools
├── models/            # Model implementations
├── mixture_of_experts/ # MoE layer support
├── validation/        # Cross-backend validation
└── utils/             # Utilities and profiling
```

## Cloud Hardware Validation

Cross-backend numerical consistency validated on 8 platforms (6 real GPU/accelerator, 2 CPU-fallback†) using Qwen3-0.6B:

| Platform | Hardware | Max Diff | Cosine Sim | Latency | Status |
|----------|----------|----------|------------|---------|--------|
| AWS | NVIDIA A10G (24GB) | 1.96e-05 | 1.000001 | 41.8 ms | PASS |
| GCP | NVIDIA T4 (16GB) | 2.67e-05 | 1.000001 | 50.8 ms | PASS |
| RunPod | NVIDIA H100 NVL (100GB) | 2.29e-05 | 1.000001 | 18.8 ms | PASS |
| AMD DevCloud | AMD MI300X (192GB) | 4.82e-05 | 1.000001 | 30.0 ms | PASS |
| GCP | TPU v5e | 1.08e-01 | 0.999980 | 47.5 ms | PASS |
| Local | Apple Silicon (MPS) | 4.58e-05 | 1.000002 | 27.8 ms | PASS |
| AWS Trainium† | Trn1.2xlarge (NeuronX) | 0.00e+00 | 1.000001 | 103.3 ms (CPU) | PASS |
| AWS Inferentia2† | inf2.xlarge (NeuronX) | 0.00e+00 | 1.000001 | 321.7 ms (CPU) | PASS |

† **CPU fallback:** NeuronX SDK compilation requires quota-enabled Trn1/Inf2 instances not available in the validation environment. These rows confirm correct CPU-path behavior (max_diff = 0.00e+00 is CPU-vs-CPU, not accelerator validation). Real NeuronX validation is pending quota approval.

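All six GPU/accelerator platforms produce numerically consistent outputs (cosine similarity > 0.999 in every row), although TPU v5e shows a noticeably larger max difference (1.08e-01) than the GPU backends.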

See [full validation report](./docs/reference/cloud-validation.md) for detailed benchmarks and results.

## Quality

- **2,563 tests** collected (hardware-gated skips on non-GPU environments)
- **0 ruff violations** -- clean linting
- **0 mypy errors** -- clean static type checking
- **Cloud validated** on 8 platforms (6 on real accelerators: A10G, T4, H100 NVL, MI300X, TPU v5e, MPS; 2 CPU-fallback†: Trainium, Inferentia2)
- **Cross-platform** tested on macOS, Linux, AWS, GCP, AMD Developer Cloud, RunPod

```bash
python3 -m pytest tests/ -q
ruff check src/ tests/
```

## Use Cases

**Cross-vendor training** -- Train on NVIDIA in the cloud, fine-tune on AMD on-prem, deploy on Trainium or TPU. Same code throughout.

**Cost optimization** -- Switch between cloud GPU types based on spot pricing without rewriting training scripts.

**Hardware migration** -- Move from one GPU vendor to another without a code rewrite.

**Research portability** -- Share models and training code that colleagues can run on whatever hardware they have.

## Documentation

| Document | Description |
|----------|-------------|
| [Installation](./docs/getting_started/installation.md) | Setup and requirements |
| [Quick Start](./docs/getting_started/quickstart.md) | First steps with TorchBridge |
| [Troubleshooting](./docs/getting_started/troubleshooting.md) | Common issues and fixes |
| [Backends Overview](./docs/backends/overview.md) | How the backend system works |
| [Backend Selection](./docs/guides/backend-selection.md) | Choosing the right backend |
| [Hardware Setup](./docs/guides/hardware-setup.md) | Driver and toolkit installation |
| [Distributed Training](./docs/guides/distributed-training.md) | Multi-GPU and multi-node |
| [Deployment](./docs/guides/deployment.md) | Export, serve, containerize |
| [CLI Reference](./docs/guides/cli.md) | Command-line tools |
| [Hardware Matrix](./docs/reference/hardware-matrix.md) | Full hardware support table |
| [Contributing](./CONTRIBUTING.md) | Development and contribution guide |
| [Changelog](./CHANGELOG.md) | Version history |

## License

TorchBridge is released under the MIT License. See the LICENSE file for details.
