Metadata-Version: 2.4
Name: mps-bitsandbytes
Version: 0.1.1
Summary: 8-bit and 4-bit quantization for PyTorch on Apple Silicon (M1/M2/M3/M4)
Author: imperatormk
License-Expression: MIT
Project-URL: Homepage, https://github.com/mpsops/mps-bitsandbytes
Project-URL: Repository, https://github.com/mpsops/mps-bitsandbytes
Project-URL: Issues, https://github.com/mpsops/mps-bitsandbytes/issues
Keywords: quantization,8bit,4bit,apple-silicon,pytorch,mps,metal,qlora,bitsandbytes
Classifier: Development Status :: 3 - Alpha
Classifier: Intended Audience :: Developers
Classifier: Intended Audience :: Science/Research
Classifier: Operating System :: MacOS :: MacOS X
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Requires-Python: >=3.10
Description-Content-Type: text/markdown
Requires-Dist: torch>=2.0.0
Dynamic: requires-python

# MPS BitsAndBytes

8-bit quantization for PyTorch on Apple Silicon (M1/M2/M3/M4).

**50% memory savings** for storing model weights, with **no speed penalty** using smart caching.

## Features

- **Linear8bit**: Drop-in replacement for `nn.Linear` with int8 weights
- **Smart caching**: Dequantize once, run fast fp16 matmul (AMX accelerated)
- **QLoRA ready**: Perfect for fine-tuning large models on Mac
- **Pure PyTorch**: No custom kernels needed, works out of the box

## Installation

```bash
pip install mps-bitsandbytes
```

Or from source:

```bash
git clone https://github.com/mpsops/mps-bitsandbytes
cd mps-bitsandbytes
pip install -e .
```

## Quick Start

```python
import torch
from mps_bitsandbytes import Linear8bit, quantize_model

# Convert existing model to 8-bit
model = YourModel().to('mps')
model = quantize_model(model, device='mps')

# Or convert individual layers
linear_8bit = Linear8bit.from_linear(some_linear_layer)

# Use normally - same API, 50% less memory for weights
output = model(input)
```

## How It Works

1. **Storage**: Weights stored as int8 (1 byte per param vs 2 bytes for fp16)
2. **First forward**: Dequantize int8 → fp16, cache the result
3. **Subsequent forwards**: Use cached fp16 weights, fast AMX matmul

This gives you:
- **50% memory savings** on disk and when loading weights
- **Same inference speed** as fp16 (once cached)
- **Compatible with QLoRA** training

## Memory Savings

| Model Size | FP16 | INT8 | Savings |
|------------|------|------|---------|
| 7B params  | 14 GB | 7 GB | 7 GB |
| 13B params | 26 GB | 13 GB | 13 GB |
| 70B params | 140 GB | 70 GB | 70 GB |

## Configuration

```python
# Default: cache enabled (fast, uses memory during inference)
layer = Linear8bit.from_linear(linear, use_cache=True)

# Memory-constrained: no cache (slower, minimum memory)
layer = Linear8bit.from_linear(linear, use_cache=False)

# Clear cache to free memory
layer.clear_cache()
```

## QLoRA Training

```python
from mps_bitsandbytes import quantize_model
from peft import get_peft_model, LoraConfig

# Load model in 8-bit
model = AutoModelForCausalLM.from_pretrained("model_name")
model = quantize_model(model.to('mps'))

# Add LoRA adapters (these stay in fp16 for gradients)
lora_config = LoraConfig(r=8, lora_alpha=16, target_modules=["q_proj", "v_proj"])
model = get_peft_model(model, lora_config)

# Train - base weights frozen in int8, LoRA in fp16
trainer.train()
```

## Benchmarks

Tested on M1 Max, batch_size=32, hidden_dim=4096:

| Method | Forward Time | Memory |
|--------|-------------|--------|
| FP16 | 1.08 ms | 100 MB |
| INT8 (cached) | 0.98 ms | 50 MB + cache |
| INT8 (no cache) | 9.65 ms | 50 MB |

## Limitations

- **First forward is slower**: Need to dequantize weights once
- **Cache uses memory**: During inference, cached fp16 weights use extra memory
- **No int8 matmul acceleration**: Apple Silicon AMX only supports fp16/fp32

For maximum memory savings during inference (no cache), use `use_cache=False`, but expect ~10x slower inference.

## Credits

- [bitsandbytes](https://github.com/bitsandbytes-foundation/bitsandbytes) - Original CUDA implementation
- [LLM.int8()](https://arxiv.org/abs/2208.07339) - Paper by Tim Dettmers et al.

## License

MIT
