Metadata-Version: 2.4
Name: torchbridge-ml
Version: 0.5.22
Summary: Hardware abstraction layer for PyTorch across NVIDIA, AMD, Trainium, and TPU backends
Author: TorchBridge Team
License: MIT
Project-URL: Homepage, https://github.com/CloudlyIO/torchbridge
Project-URL: Documentation, https://torchbridge.readthedocs.io
Project-URL: Repository, https://github.com/CloudlyIO/torchbridge
Project-URL: Bug Tracker, https://github.com/CloudlyIO/torchbridge/issues
Keywords: pytorch,hardware-abstraction,multi-backend,cuda,amd,trainium,tpu,machine-learning,deep-learning
Classifier: Development Status :: 5 - Production/Stable
Classifier: Intended Audience :: Developers
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: MIT License
Classifier: Operating System :: OS Independent
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Programming Language :: Python :: 3.13
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Classifier: Topic :: Software Development :: Libraries :: Python Modules
Requires-Python: >=3.10
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: torch<3.0.0,>=2.0.0
Requires-Dist: numpy<3.0.0,>=1.21.0
Requires-Dist: pybind11<3.0.0,>=2.10.0
Provides-Extra: dev
Requires-Dist: pytest<9.0.0,>=7.0.0; extra == "dev"
Requires-Dist: pytest-cov<6.0.0,>=4.0.0; extra == "dev"
Requires-Dist: pytest-xdist<4.0.0,>=3.0.0; extra == "dev"
Requires-Dist: pytest-benchmark<5.0.0,>=4.0.0; extra == "dev"
Requires-Dist: hypothesis<7.0.0,>=6.0.0; extra == "dev"
Requires-Dist: matplotlib<4.0.0,>=3.5.0; extra == "dev"
Requires-Dist: seaborn<1.0.0,>=0.11.0; extra == "dev"
Requires-Dist: jupyter<2.0.0,>=1.0.0; extra == "dev"
Requires-Dist: tensorboard<3.0.0,>=2.9.0; extra == "dev"
Requires-Dist: ruff<1.0.0,>=0.4.0; extra == "dev"
Requires-Dist: mypy<2.0.0,>=1.0.0; extra == "dev"
Requires-Dist: pre-commit<4.0.0,>=3.5.0; extra == "dev"
Requires-Dist: bandit[toml]<2.0.0,>=1.7.0; extra == "dev"
Provides-Extra: all
Requires-Dist: transformers<5.0.0,>=4.35.0; extra == "all"
Requires-Dist: datasets<3.0.0,>=2.14.0; extra == "all"
Requires-Dist: tokenizers<1.0.0,>=0.14.0; extra == "all"
Requires-Dist: triton<4.0.0,>=2.1.0; extra == "all"
Requires-Dist: flash-attn<3.0.0,>=2.3.0; extra == "all"
Requires-Dist: accelerate<1.0.0,>=0.24.0; extra == "all"
Provides-Extra: cloud
Requires-Dist: boto3<2.0.0,>=1.28.0; extra == "cloud"
Requires-Dist: google-cloud-storage<3.0.0,>=2.10.0; extra == "cloud"
Requires-Dist: azure-storage-blob<13.0.0,>=12.17.0; extra == "cloud"
Requires-Dist: kubernetes<32.0.0,>=27.2.0; extra == "cloud"
Provides-Extra: serving
Requires-Dist: fastapi<1.0.0,>=0.103.0; extra == "serving"
Requires-Dist: uvicorn[standard]<1.0.0,>=0.23.0; extra == "serving"
Requires-Dist: torchserve<1.0.0,>=0.8.0; extra == "serving"
Requires-Dist: gradio<6.0.0,>=3.45.0; extra == "serving"
Requires-Dist: streamlit<2.0.0,>=1.27.0; extra == "serving"
Provides-Extra: monitoring
Requires-Dist: prometheus-client<1.0.0,>=0.17.0; extra == "monitoring"
Requires-Dist: wandb<1.0.0,>=0.16.0; extra == "monitoring"
Requires-Dist: tensorboard<3.0.0,>=2.14.0; extra == "monitoring"
Requires-Dist: mlflow<3.0.0,>=2.7.0; extra == "monitoring"
Requires-Dist: optuna<4.0.0,>=3.4.0; extra == "monitoring"
Provides-Extra: benchmark
Requires-Dist: memory-profiler<1.0.0,>=0.61.0; extra == "benchmark"
Requires-Dist: py-spy<1.0.0,>=0.3.14; extra == "benchmark"
Requires-Dist: torch-tb-profiler<1.0.0,>=0.4.0; extra == "benchmark"
Requires-Dist: psutil<6.0.0,>=5.9.0; extra == "benchmark"
Requires-Dist: gpustat<2.0.0,>=1.1.0; extra == "benchmark"
Provides-Extra: tracing
Requires-Dist: opentelemetry-api<2.0.0,>=1.20.0; extra == "tracing"
Provides-Extra: docs
Requires-Dist: sphinx<8.0.0,>=7.0.0; extra == "docs"
Requires-Dist: furo<2026.0.0,>=2024.1.0; extra == "docs"
Requires-Dist: myst-parser<4.0.0,>=2.0.0; extra == "docs"
Requires-Dist: sphinx-autobuild<2026.0.0,>=2024.1.0; extra == "docs"
Dynamic: license-file

# TorchBridge

**Your PyTorch code is locked to one GPU vendor.** CUDA calls, NCCL hardcoding, vendor-specific precision tricks -- they break the moment you switch hardware. TorchBridge is a hardware abstraction layer that makes your models run on NVIDIA, AMD, Trainium, and TPU without code changes, and **validates that outputs match across backends**.

[![Version](https://img.shields.io/pypi/v/torchbridge-ml?label=version&color=green)](./CHANGELOG.md) [![Tests](https://img.shields.io/badge/tests-1%2C464%20passed-blue)](./docs/reference/hardware-matrix.md) [![Cloud GPU](https://img.shields.io/badge/cloud%20GPU-9%2F9%20passed-brightgreen)](./docs/reference/cloud-validation.md) [![AWS A10G](https://img.shields.io/badge/AWS%20A10G-PASS-brightgreen)](./docs/reference/cloud-validation.md) [![GCP T4](https://img.shields.io/badge/GCP%20T4-PASS-brightgreen)](./docs/reference/cloud-validation.md) [![Python](https://img.shields.io/badge/python-3.10%2B-blue)](https://python.org) [![PyTorch](https://img.shields.io/badge/pytorch-2.0%2B-orange)](https://pytorch.org)

## What is TorchBridge?

PyTorch lets you build models. TorchBridge lets you run them **anywhere**.

Most teams write hardware-specific code -- CUDA calls for NVIDIA, ROCm setup for AMD, NeuronX setup for Trainium, XLA boilerplate for TPU. When the hardware changes, the code breaks. TorchBridge eliminates that problem with a **unified API** that detects your hardware and adapts automatically.

```
Your model code
      |
  TorchBridge HAL
      |
  +---------+---------+-----------+---------+
  | NVIDIA  |   AMD   | Trainium  |   TPU   |
  | CUDA    |  ROCm   |  NeuronX  |   XLA   |
  +---------+---------+-----------+---------+
```

**What it does:**
- **Backend detection** -- automatically identifies available accelerators
- **Vendor adapters** -- translates unified API calls to vendor-specific operations
- **Precision management** -- handles FP32/FP16/BF16/FP8 across backends
- **Memory optimization** -- gradient checkpointing, activation offloading, memory pooling
- **Checkpoint portability** -- save on one backend, load on another
- **Distributed training** -- tensor/pipeline/data parallelism across backend types

## Quick Start

```bash
pip install torchbridge-ml

# Verify
python3 -c "import torchbridge; print(f'TorchBridge v{torchbridge.__version__} ready')"
```

For development:

```bash
git clone https://github.com/CloudlyIO/torchbridge.git
cd torchbridge
pip install -e ".[dev]"
```

### Detect Hardware

```python
from torchbridge.backends import BackendFactory, detect_best_backend

backend_type = detect_best_backend()  # NVIDIA, AMD, Trainium, TPU, or CPU
backend = BackendFactory.create(backend_type)
print(backend.get_device_info())
```

### Run on Any Backend

```python
import torch
from torchbridge import TorchBridgeConfig, UnifiedManager

config = TorchBridgeConfig.for_training()
manager = UnifiedManager(config)

model = torch.nn.Sequential(
    torch.nn.Linear(768, 3072),
    torch.nn.GELU(),
    torch.nn.Linear(3072, 768),
)

optimized_model = manager.optimize(model)
```

### Validate

```python
from torchbridge import UnifiedValidator

validator = UnifiedValidator()
results = validator.validate_model(optimized_model, input_shape=(1, 768))
print(f"Validation: {results.passed}/{results.total_tests} tests passed")
```
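
Outputs from different backends rarely match bit for bit, because kernels on different hardware can reorder floating-point operations. A common way to compare them is the elementwise rule `torch.allclose` applies: |a - b| <= atol + rtol * |b|. The stdlib-only sketch below illustrates that check. The helper names `outputs_match` and `max_abs_diff` and the default tolerances (copied from `torch.allclose`) are assumptions for illustration. They are not the internals of TorchBridge's `UnifiedValidator`.

```python
import math

def outputs_match(reference, candidate, rtol=1e-5, atol=1e-8):
    """Hypothetical helper: elementwise |a - b| <= atol + rtol * |b|,
    the same rule torch.allclose uses (reference plays the role of b)."""
    if len(reference) != len(candidate):
        return False
    for b, a in zip(reference, candidate):
        # Like torch.allclose with equal_nan=False, any NaN counts as a mismatch
        if math.isnan(a) or math.isnan(b):
            return False
        if abs(a - b) > atol + rtol * abs(b):
            return False
    return True

def max_abs_diff(reference, candidate):
    """Largest elementwise deviation, useful when reporting a failed check."""
    return max(abs(a - b) for b, a in zip(reference, candidate))

# A CPU reference vs. an accelerator result that differs only by rounding noise
cpu_out = [0.1234567, -1.5, 3.0]
gpu_out = [0.1234568, -1.5000001, 3.0]
print(outputs_match(cpu_out, gpu_out))            # True
print(outputs_match(cpu_out, gpu_out, rtol=0.0))  # False: atol alone is too tight
```

Keeping `rtol` nonzero matters for cross-backend checks, since absolute error grows with the magnitude of the activations.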

## Supported Backends

| Backend | Hardware | Precision | Status |
|---------|----------|-----------|--------|
| **NVIDIA** | B200, H100, H200, A100, L4, T4 | FP4, FP8, BF16, FP16, FP32 | Production |
| **AMD** | MI350X, MI325X, MI300X, MI200 | FP4, FP8, BF16, FP16, FP32 | Production |
| **Trainium** | Trn1, Trn2, Trn3 (AWS NeuronX) | BF16, FP16, FP32 | Supported |
| **TPU** | v4, v5e, v5p, v6e, v7 | BF16, FP32 | Production |
| **CPU** | x86, ARM (Apple Silicon) | FP32, BF16 | Fallback |

See [Hardware Matrix](./docs/reference/hardware-matrix.md) for full details.

## Key Features

### Backend Detection and Adaptation
Automatically identifies available hardware and selects the best available backend, falling back to CPU when no accelerator is found. No code changes are needed when moving between GPU vendors or cloud providers.

### Vendor Adapters
Each backend implements a common `BaseBackend` interface. Your code calls `manager.optimize(model)` and the correct vendor-specific operations execute underneath -- CUDA on NVIDIA, HIP on AMD, NeuronX on Trainium, XLA on TPU.
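
A minimal sketch of that adapter pattern: an abstract base class fixes the contract, shared logic lives in the base, and each vendor subclass supplies only the vendor-specific parts. The class and method names below (`BackendAdapter`, `device_string`, `compile_hint`) are hypothetical illustrations. They are not TorchBridge's actual `BaseBackend` interface.

```python
from abc import ABC, abstractmethod

class BackendAdapter(ABC):
    """Hypothetical common interface; vendor subclasses fill in the details."""

    @abstractmethod
    def device_string(self) -> str:
        """Device identifier that tensors should be placed on."""

    @abstractmethod
    def compile_hint(self) -> str:
        """Name of the vendor toolchain that lowers the model."""

    def optimize(self, model_name: str) -> str:
        # Shared logic lives in the base class; vendor specifics come from subclasses
        return f"{model_name} -> {self.compile_hint()} on {self.device_string()}"

class NvidiaAdapter(BackendAdapter):
    def device_string(self) -> str:
        return "cuda:0"

    def compile_hint(self) -> str:
        return "CUDA"

class AmdAdapter(BackendAdapter):
    # ROCm builds of PyTorch reuse the "cuda" device type for AMD GPUs
    def device_string(self) -> str:
        return "cuda:0"

    def compile_hint(self) -> str:
        return "HIP"

class TpuAdapter(BackendAdapter):
    def device_string(self) -> str:
        return "xla:0"

    def compile_hint(self) -> str:
        return "XLA"

ADAPTERS = {"nvidia": NvidiaAdapter, "amd": AmdAdapter, "tpu": TpuAdapter}

for name, cls in ADAPTERS.items():
    print(cls().optimize("mlp"))
```

Because `BackendAdapter` declares abstract methods, Python refuses to instantiate a subclass that forgets one, which catches an incomplete vendor adapter at construction time rather than mid-training.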

### Precision Management
Configure precision once. TorchBridge handles the details per backend -- FP8 on H100, BF16 where supported, FP16 as fallback. Mixed-precision training with `torch.amp` autocast works across all backends.
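
The fallback described above can be modeled as walking a preference list and picking the first dtype the device supports. In the sketch below, the preference order follows this paragraph (FP8, then BF16, then FP16, then FP32). The per-device capability sets and the `select_precision` helper are illustrative assumptions, loosely based on the Supported Backends table, and are not values read from TorchBridge.

```python
# Preferred training precisions, most aggressive first (assumed policy)
PREFERENCE = ("fp8", "bf16", "fp16", "fp32")

# Illustrative capability sets; real support depends on the chip and software stack
CAPABILITIES = {
    "h100": {"fp8", "bf16", "fp16", "fp32"},
    "a100": {"bf16", "fp16", "fp32"},
    "t4": {"fp16", "fp32"},
    "tpu_v5e": {"bf16", "fp32"},
}

def select_precision(device: str, requested: str = "fp8") -> str:
    """Hypothetical helper: requested dtype if supported, else the next one down."""
    supported = CAPABILITIES.get(device, {"fp32"})  # unknown devices get FP32
    start = PREFERENCE.index(requested)
    for dtype in PREFERENCE[start:]:
        if dtype in supported:
            return dtype
    return "fp32"

for dev in ("h100", "a100", "t4", "tpu_v5e", "unknown"):
    print(dev, select_precision(dev))
```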

### Memory Optimization
Gradient checkpointing, activation offloading, optimizer state sharding, and memory pooling. These work consistently whether you're on a single GPU or a multi-node cluster.
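
Memory pooling generally means caching freed blocks and handing them out again for later requests of a similar size, instead of returning them to the device driver every step. The sketch below rounds requests up to power-of-two size classes and keeps one free list per class. It is a simplified illustration of the general caching-allocator idea. `SizeClassPool` is a hypothetical name, and the power-of-two rounding is an assumption; neither reflects TorchBridge's or PyTorch's actual allocator.

```python
from collections import defaultdict

class SizeClassPool:
    """Hypothetical pool: reuse freed blocks from the same power-of-two size class."""

    def __init__(self, min_block: int = 512):
        self.min_block = min_block
        self.free_lists = defaultdict(list)  # size class -> ids of cached free blocks
        self.next_id = 0
        self.fresh_allocations = 0  # how often we had to ask the "driver" for memory

    def size_class(self, nbytes: int) -> int:
        size = self.min_block
        while size < nbytes:
            size *= 2
        return size

    def allocate(self, nbytes: int) -> tuple[int, int]:
        cls = self.size_class(nbytes)
        if self.free_lists[cls]:
            return cls, self.free_lists[cls].pop()  # cache hit: reuse a block
        self.fresh_allocations += 1                  # cache miss: new block
        self.next_id += 1
        return cls, self.next_id

    def free(self, block: tuple[int, int]) -> None:
        cls, block_id = block
        self.free_lists[cls].append(block_id)        # keep the block for reuse

pool = SizeClassPool()
for step in range(3):  # three training steps with identical activation sizes
    a = pool.allocate(3000)  # rounds up to 4096
    b = pool.allocate(600)   # rounds up to 1024
    pool.free(a)
    pool.free(b)
print(pool.fresh_allocations)  # 2: only the first step allocated new blocks
```

The trade-off is fragmentation: rounding up wastes up to half of each block, which is why production allocators use finer-grained size classes and block splitting.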

### Checkpoint Portability
Save a checkpoint on NVIDIA hardware, load it on AMD, Trainium, or TPU. TorchBridge handles device mapping and dtype conversion.
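
Portability has two parts. Device placement is usually remapped at load time (in plain PyTorch, via `torch.load(path, map_location=...)`). Dtypes the target cannot run also need a conversion rule. The stdlib sketch below plans the dtype half over a state-dict-like mapping from tensor names to dtype names. The helper `plan_dtype_conversion` and its widening rules are assumptions for illustration, not TorchBridge's conversion logic. The sketch also ignores the per-tensor scaling factors that real FP8 checkpoints usually carry.

```python
# Assumed lossless widening rules, applied when the target lacks a saved dtype
WIDEN = {"fp8": "bf16", "bf16": "fp32", "fp16": "fp32"}

def plan_dtype_conversion(saved: dict[str, str], target_supported: set[str]) -> dict[str, str]:
    """Hypothetical planner: map each tensor name to the dtype to load it as."""
    plan = {}
    for name, dtype in saved.items():
        current = dtype
        # Widen step by step until the target backend supports the dtype
        while current not in target_supported:
            if current not in WIDEN:
                raise ValueError(f"no supported dtype for {name} ({dtype})")
            current = WIDEN[current]
        plan[name] = current
    return plan

# Checkpoint saved on an FP8-capable NVIDIA GPU, loaded on a TPU (BF16/FP32)
saved = {"linear1.weight": "fp8", "linear1.bias": "bf16", "norm.weight": "fp32"}
print(plan_dtype_conversion(saved, {"bf16", "fp32"}))
# {'linear1.weight': 'bf16', 'linear1.bias': 'bf16', 'norm.weight': 'fp32'}
```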

### Distributed Training
Tensor parallelism, pipeline parallelism, and FSDP with a unified API. The same distributed training script runs on NVIDIA DGX, AMD Instinct, Trainium instances, or TPU pods.
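
When tensor, pipeline, and data parallelism are combined, the world size factors as `world_size = tp * pp * dp`, and each global rank maps to one coordinate in that grid. The sketch below assumes a layout where the tensor-parallel index varies fastest, so a tensor-parallel group occupies consecutive ranks and can share one node's fast interconnect. Data parallel varies next and pipeline slowest. This ordering and the helper names are assumptions, not necessarily how TorchBridge assigns ranks.

```python
from typing import NamedTuple

class ParallelCoord(NamedTuple):
    pp: int  # pipeline stage
    dp: int  # data-parallel replica
    tp: int  # tensor-parallel shard

def rank_to_coord(rank: int, tp: int, pp: int, dp: int) -> ParallelCoord:
    """Decompose a global rank; the tensor-parallel index varies fastest (assumed layout)."""
    world_size = tp * pp * dp
    if not 0 <= rank < world_size:
        raise ValueError(f"rank {rank} outside world of size {world_size}")
    tp_idx = rank % tp
    dp_idx = (rank // tp) % dp
    pp_idx = rank // (tp * dp)
    return ParallelCoord(pp_idx, dp_idx, tp_idx)

def coord_to_rank(c: ParallelCoord, tp: int, dp: int) -> int:
    """Inverse of rank_to_coord for the same layout."""
    return c.pp * (tp * dp) + c.dp * tp + c.tp

def tensor_parallel_group(rank: int, tp: int, pp: int, dp: int) -> list[int]:
    """All ranks sharing this rank's pipeline stage and data replica."""
    c = rank_to_coord(rank, tp, pp, dp)
    return [coord_to_rank(ParallelCoord(c.pp, c.dp, t), tp, dp) for t in range(tp)]

# 16 GPUs = 2-way tensor x 2-way pipeline x 4-way data parallel
print(rank_to_coord(13, tp=2, pp=2, dp=4))          # ParallelCoord(pp=1, dp=2, tp=1)
print(tensor_parallel_group(13, tp=2, pp=2, dp=4))  # [12, 13]
```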

## Code Examples

### Backend-Agnostic Training

```python
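import torch
from torchbridge.backends import BackendFactory, detect_best_backend

backend = BackendFactory.create(detect_best_backend())
device = backend.device

# Placeholder model, loss, and data; substitute your own
model = torch.nn.Sequential(
    torch.nn.Linear(768, 3072),
    torch.nn.GELU(),
    torch.nn.Linear(3072, 768),
).to(device)
criterion = torch.nn.MSELoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
dataset = torch.utils.data.TensorDataset(torch.randn(64, 768), torch.randn(64, 768))
train_loader = torch.utils.data.DataLoader(dataset, batch_size=16)

# PyTorch native AMP; torch.amp.GradScaler requires PyTorch 2.3+.
# Loss scaling only matters for FP16 on CUDA/ROCm; elsewhere the scaler is a pass-through.
scaler = torch.amp.GradScaler(device.type, enabled=device.type == "cuda")
for inputs, targets in train_loader:
    inputs, targets = inputs.to(device), targets.to(device)
    optimizer.zero_grad()
    with torch.amp.autocast(device.type):
        loss = criterion(model(inputs), targets)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()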
```

### Hardware Capability Queries

```python
from torchbridge.backends.nvidia import NVIDIABackend

nvidia = NVIDIABackend()
print(nvidia.get_device_info())  # GPU model, compute capability, memory
print(nvidia.supports_fp8())     # True on H100+
```
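
Hardware FP8 tensor cores on NVIDIA GPUs start at compute capability 8.9 (Ada Lovelace, e.g. L4) and continue through Hopper (9.0, e.g. H100) and Blackwell (10.0, e.g. B200). A check like `supports_fp8()` can therefore reduce to a version comparison. In PyTorch the `(major, minor)` tuple comes from `torch.cuda.get_device_capability()`; the stdlib sketch below takes it as a parameter instead. `fp8_capable` is a hypothetical helper. This README does not state whether TorchBridge's `supports_fp8()` also returns True on Ada parts.

```python
# (major, minor) compute capability, as returned by torch.cuda.get_device_capability()
FP8_MIN_CAPABILITY = (8, 9)  # Ada Lovelace; Hopper is (9, 0), Blackwell B200 is (10, 0)

def fp8_capable(capability: tuple[int, int]) -> bool:
    """Hypothetical helper: tuple comparison orders (major, minor) correctly."""
    return tuple(capability) >= FP8_MIN_CAPABILITY

GPUS = {"T4": (7, 5), "A100": (8, 0), "A10G": (8, 6), "L4": (8, 9), "H100": (9, 0), "B200": (10, 0)}
for name, cap in GPUS.items():
    print(f"{name:5} sm_{cap[0]}{cap[1]}  fp8={fp8_capable(cap)}")
```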

### Cross-Backend Model Export

```python
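import torch
from torchbridge.deployment import export_to_torchscript, export_to_onnx, export_to_safetensors

# Any nn.Module works; this mirrors the Quick Start model
model = torch.nn.Sequential(
    torch.nn.Linear(768, 3072),
    torch.nn.GELU(),
    torch.nn.Linear(3072, 768),
).eval()
sample_input = torch.randn(1, 768)

export_to_torchscript(model, output_path="model.pt", sample_input=sample_input)
export_to_onnx(model, output_path="model.onnx", sample_input=sample_input)
export_to_safetensors(model, output_path="model.safetensors")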
```
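
Part of what makes SafeTensors portable across backends is its simple, framework-neutral layout. A file starts with an 8-byte little-endian header length, followed by a JSON header mapping each tensor name to `dtype`, `shape`, and `data_offsets` (relative to the start of the data buffer). Raw little-endian tensor bytes come last. The stdlib-only sketch below writes and reads that layout for float32 data to show the structure. `save_f32` and `load_f32` are hypothetical helpers, not TorchBridge's `export_to_safetensors`; real code should use the `safetensors` library.

```python
import json
import struct

def save_f32(path: str, tensors: dict[str, tuple[list[int], list[float]]]) -> None:
    """Write {name: (shape, flat_values)} as float32 in the safetensors layout."""
    header, chunks, offset = {}, [], 0
    for name, (shape, values) in tensors.items():
        data = struct.pack(f"<{len(values)}f", *values)  # little-endian float32
        header[name] = {"dtype": "F32", "shape": shape,
                        "data_offsets": [offset, offset + len(data)]}
        chunks.append(data)
        offset += len(data)
    header_bytes = json.dumps(header).encode("utf-8")
    header_bytes += b" " * (-len(header_bytes) % 8)  # pad with spaces to 8-byte alignment
    with open(path, "wb") as f:
        f.write(struct.pack("<Q", len(header_bytes)))  # u64 header length
        f.write(header_bytes)
        f.write(b"".join(chunks))

def load_f32(path: str) -> dict[str, tuple[list[int], list[float]]]:
    """Read back a file written by save_f32 (F32 tensors only)."""
    with open(path, "rb") as f:
        (header_len,) = struct.unpack("<Q", f.read(8))
        header = json.loads(f.read(header_len))
        buffer = f.read()
    out = {}
    for name, info in header.items():
        if name == "__metadata__":
            continue  # optional free-form string metadata
        if info["dtype"] != "F32":
            raise ValueError(f"sketch only handles F32, got {info['dtype']}")
        begin, end = info["data_offsets"]
        count = (end - begin) // 4
        out[name] = (info["shape"], list(struct.unpack(f"<{count}f", buffer[begin:end])))
    return out

save_f32("tiny.safetensors", {"linear.weight": ([2, 2], [1.0, 2.0, 3.0, 4.0]),
                              "linear.bias": ([2], [0.5, -0.5])})
print(load_f32("tiny.safetensors"))
```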

## Project Structure

```
src/torchbridge/
├── backends/          # Vendor-specific backend implementations
│   ├── nvidia/        #   NVIDIA CUDA backend
│   ├── amd/           #   AMD ROCm backend
│   ├── trainium/      #   AWS Trainium/NeuronX backend
│   └── tpu/           #   Google TPU/XLA backend
├── hardware/          # Hardware detection and abstraction
├── precision/         # FP8 training and precision management
├── attention/         # Attention mechanisms (unified API)
├── advanced_memory/   # Memory optimization strategies
├── distributed_scale/ # Distributed training
├── deployment/        # Model export and serving
├── monitoring/        # Metrics, logging, health checks
├── optimizations/     # Optimization patterns and strategies
├── core/              # Core config, management, optimized layers
├── cli/               # Command-line tools
├── models/            # Model implementations
├── mixture_of_experts/ # MoE layer support
├── validation/        # Cross-backend validation
└── utils/             # Utilities and profiling
```

## Cloud GPU Validation

All 5 use cases validated on real GPU hardware across AWS and GCP:

| Use Case | AWS A10G | GCP T4 | Description |
|----------|----------|--------|-------------|
| Export Pipeline | PASS | PASS | TorchScript, ONNX, SafeTensors export with validation |
| LLM Optimization | PASS | PASS | Qwen3/DeepSeek optimization with backend-specific tuning |
| CI/CD Validation | PASS | PASS | Diagnostics, benchmarks, cross-backend checks |
| Backend Training | PASS | PASS | AMP training with auto backend detection |
| Cross-Backend Validation | PASS | PASS | Model, hardware, config, and output consistency |

**Platforms tested:**
- **AWS g5.xlarge** -- NVIDIA A10G 24GB, PyTorch 2.9.1+cu130
- **GCP n1-standard-4** -- NVIDIA T4 16GB, PyTorch 2.7.1+cu128

See [full validation report](./docs/reference/cloud-validation.md) for detailed benchmarks and results.

## Quality

- **1,464 tests** passing across all modules
- **0 ruff violations** -- clean linting
- **0 mypy errors** -- clean type checking
- **Cloud validated** on NVIDIA A10G (AWS), T4 (GCP), and AMD MI300X -- 5/5 use cases pass
- **Cross-platform** tested on macOS, Linux, AWS, GCP, AMD Developer Cloud

```bash
python3 -m pytest tests/ -q
ruff check src/ tests/
```

## Use Cases

**Cross-vendor training** -- Train on NVIDIA in the cloud, fine-tune on AMD on-prem, deploy on Trainium or TPU. Same code throughout.

**Cost optimization** -- Switch between cloud GPU types based on spot pricing without rewriting training scripts.

**Hardware migration** -- Move from one GPU vendor to another without a code rewrite.

**Research portability** -- Share models and training code that colleagues can run on whatever hardware they have.

## Documentation

| Document | Description |
|----------|-------------|
| [Installation](./docs/getting_started/installation.md) | Setup and requirements |
| [Quick Start](./docs/getting_started/quickstart.md) | First steps with TorchBridge |
| [Troubleshooting](./docs/getting_started/troubleshooting.md) | Common issues and fixes |
| [Backends Overview](./docs/backends/overview.md) | How the HAL works |
| [Backend Selection](./docs/guides/backend-selection.md) | Choosing the right backend |
| [Hardware Setup](./docs/guides/hardware-setup.md) | Driver and toolkit installation |
| [Distributed Training](./docs/guides/distributed-training.md) | Multi-GPU and multi-node |
| [Deployment](./docs/guides/deployment.md) | Export, serve, containerize |
| [CLI Reference](./docs/guides/cli.md) | Command-line tools |
| [Hardware Matrix](./docs/reference/hardware-matrix.md) | Full hardware support table |
| [Contributing](./CONTRIBUTING.md) | Development and contribution guide |
| [Changelog](./CHANGELOG.md) | Version history |

## License

TorchBridge is released under the MIT License. See the [LICENSE](./LICENSE) file for details.
