Metadata-Version: 2.4
Name: mambax
Version: 0.2.0
Summary: Optimized Mamba implementation with chunk processing and ONNX export
Author: Oleg Kufa
Author-email: os.schischkin@gmail.com
Project-URL: Source Code, https://github.com/yourusername/mambax
Project-URL: Bug Tracker, https://github.com/yourusername/mambax/issues
Keywords: mamba,pytorch,deep learning
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: License :: OSI Approved :: MIT License
Classifier: Operating System :: OS Independent
Classifier: Intended Audience :: Science/Research
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Requires-Python: >=3.10
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: torch>=2.0.0
Requires-Dist: numpy>=1.21.0
Dynamic: author
Dynamic: author-email
Dynamic: classifier
Dynamic: description
Dynamic: description-content-type
Dynamic: keywords
Dynamic: license-file
Dynamic: project-url
Dynamic: requires-dist
Dynamic: requires-python
Dynamic: summary

# MambaX

A PyTorch implementation of the Mamba architecture with features aimed at production use:

1. **ONNX Export** - Full model export support for deployment
2. **Chunk Processing** - Handles a whole chunk in a single forward pass (no token loops)
3. **CPU-First** - Optimized execution without CUDA dependencies

## Installation

Install the package directly from PyPI:

```bash
pip install mambax
```

## Key Advantages

- **Production Ready**: ONNX-compatible for serving
- **No Token Loops**: Processes an entire chunk in a single forward pass
- **Hardware Agnostic**: Runs on both CPU and GPU
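
The idea behind chunk processing with a carried state can be illustrated with a toy example. The sketch below is not MambaX code and not Mamba's selective scan: it uses a plain scalar recurrence `h_t = a * h_{t-1} + b * x_t` with made-up coefficients `a` and `b`, only to show that splitting a sequence into chunks and passing the final state forward reproduces the full-sequence result.

```python
def scan(xs, h0=0.0, a=0.9, b=0.1):
    """Run h_t = a * h_{t-1} + b * x_t over xs.

    a and b are hypothetical constants chosen for illustration.
    Returns the list of states and the final state.
    """
    h = h0
    ys = []
    for x in xs:
        h = a * h + b * x
        ys.append(h)
    return ys, h


xs = [float(i) for i in range(16)]

# One pass over the whole sequence.
full, _ = scan(xs)

# The same sequence in chunks of 4, carrying the state between chunks.
chunked, h = [], 0.0
for start in range(0, len(xs), 4):
    ys, h = scan(xs[start:start + 4], h0=h)
    chunked.extend(ys)

# Same operations in the same order, so the results match exactly.
assert chunked == full
```

The toy version loops over tokens internally for readability; the point is only the role of the carried state, which is what the state cache in the usage examples represents.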

## Acknowledgements
MambaX builds on the reference work in [alxndrTL/mamba.py](https://github.com/alxndrTL/mamba.py).

## Usage

### 1. Standard Forward Pass
```python
import torch
from mambax import Mamba

# Initialize model
model = Mamba(
    d_model=512,
    d_inner=1024,
    d_conv=4,
    d_state=16,
    dt_rank=64,
    use_cuda=False
)

# Process full sequence
x = torch.rand(1, 128, 512)  # (batch, seq_len, dim)
output = model(x)  # single forward pass
```

### 2. Single Token Processing
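```python
# Caches carry state between calls; start from zeros
state_cache = torch.zeros(1, 1024, 16)  # (batch, d_inner, d_state)
conv_cache = torch.zeros(1, 1024, 3)    # (batch, d_inner, d_conv-1)

x_token = torch.rand(1, 1, 512)  # (batch, 1, dim)
output, new_state, new_conv = model(x_token, state_cache, conv_cache)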
```
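
The model above is configured with `d_conv=4`. A causal convolution of that width needs the three inputs that precede the current token, which is plausibly what the convolution cache holds. The following single-channel sketch in plain Python shows such a rolling window; it is an assumption about the mechanism, not MambaX's actual implementation, and `causal_conv_step` and the kernel values are hypothetical.

```python
def causal_conv_step(x_t, cache, weights):
    """One causal-convolution step for a single channel.

    cache holds the previous len(weights) - 1 inputs, oldest first.
    Returns the output for x_t and the updated cache.
    """
    window = cache + [x_t]  # d_conv values, oldest first
    y_t = sum(w * v for w, v in zip(weights, window))
    return y_t, window[1:]  # drop the oldest input


d_conv = 4
weights = [0.1, 0.2, 0.3, 0.4]  # hypothetical kernel, one weight per tap
cache = [0.0] * (d_conv - 1)    # empty cache acts as zero padding

outputs = []
for x_t in [1.0, 2.0, 3.0, 4.0, 5.0]:
    y_t, cache = causal_conv_step(x_t, cache, weights)
    outputs.append(y_t)
```

After the loop, `cache` contains the last three inputs, ready for the next token.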

### 3. Chunk-Based Processing
```python
# Initialize with empty caches
state_cache = torch.zeros(1, 1024, 16)  # (batch, d_inner, d_state)
conv_cache = torch.zeros(1, 1024, 3)    # (batch, d_inner, d_conv-1)

# Process a chunk of 8 tokens in one call; pass new_state and new_conv to the next call
x_chunk = torch.rand(1, 8, 512)
output, new_state, new_conv = model(x_chunk, state_cache, conv_cache)
```
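
For a long input, the chunk call can be repeated while threading the returned caches into the next call. The helper below is a minimal sketch of that loop. It assumes only the call shape shown above, `(x, state_cache, conv_cache) -> (output, new_state, new_conv)`. The names `stream` and `dummy_step` are hypothetical, and the stand-in step function exists only so the sketch runs without `mambax` installed.

```python
def stream(step, chunks, state_cache, conv_cache):
    """Feed chunks through step in order, threading both caches along."""
    outputs = []
    for chunk in chunks:
        out, state_cache, conv_cache = step(chunk, state_cache, conv_cache)
        outputs.append(out)
    return outputs, state_cache, conv_cache


def dummy_step(chunk, state_cache, conv_cache):
    """Stand-in for the model: doubles each value, counts tokens seen in
    state_cache, and keeps the latest chunk as conv_cache."""
    return [2 * v for v in chunk], state_cache + len(chunk), chunk


outs, state, conv = stream(dummy_step, [[1, 2], [3, 4], [5]], 0, None)
```

With the real model, `step` would presumably be `model` itself, and the chunks could come from something like `torch.split(x, 8, dim=1)`, with the per-chunk outputs concatenated along the sequence dimension afterwards.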
