Metadata-Version: 2.1
Name: mamba-former
Version: 0.0.3
Summary: Paper - Pytorch
Home-page: https://github.com/kyegomez/MambaFormer
License: MIT
Keywords: artificial intelligence,deep learning,optimizers,Prompt Engineering
Author: Kye Gomez
Author-email: kye@apac.ai
Requires-Python: >=3.9,<4.0
Classifier: Development Status :: 4 - Beta
Classifier: Intended Audience :: Developers
Classifier: License :: OSI Approved :: MIT License
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Requires-Dist: einops
Requires-Dist: torch
Requires-Dist: zetascale (==2.2.7)
Project-URL: Documentation, https://github.com/kyegomez/MambaFormer
Project-URL: Repository, https://github.com/kyegomez/MambaFormer
Description-Content-Type: text/markdown

[![Multi-Modality](agorabanner.png)](https://discord.gg/qUtxnK2NMf)

# MambaFormer
Implementation of MambaFormer in PyTorch, built with Zeta, from the paper "Can Mamba Learn How to Learn? A Comparative Study on In-Context Learning Tasks".

## install
`pip3 install mamba-former`

## usage
```python
import torch
from mamba_former.main import MambaFormer

# Example input: one sequence of 100 random token ids drawn from [1, 1000)
x = torch.randint(1, 1000, (1, 100))  # shape: (batch, seq_len)
# Token ids are integers and must be valid vocabulary indices (< num_tokens)

# Model
model = MambaFormer(
    dim=512,  # Model (embedding) dimension
    num_tokens=1000,  # Vocabulary size: number of unique token ids
    depth=6,  # Number of layers
    d_state=512,  # Dimension of the Mamba (SSM) state
    d_conv=128,  # Width of the Mamba block's 1D convolution kernel
    heads=8,  # Number of attention heads
    dim_head=64,  # Dimension of each attention head
    return_tokens=True,  # Whether to return the tokens in the output
)

# Forward pass
out = model(x)  # Perform a forward pass through the model

# If training
# out = model(x, return_loss=True)  # Perform a forward pass and calculate the loss

# Print the output
print(out)  # Print the output tensor
print(out.shape)  # Print the shape of the output tensor

```
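
The training call in the usage example is only shown as a comment. If you would rather compute the loss yourself, the following is a minimal sketch of one next-token training step. It assumes that with `return_tokens=True` the model returns logits of shape `(batch, seq_len, num_tokens)`; this README does not confirm that shape, so treat it as an assumption. To keep the snippet self-contained, a hypothetical stand-in (`nn.Embedding` followed by `nn.Linear`) takes the place of `MambaFormer`, and the `train_step` helper, the AdamW optimizer, and the learning rate are illustrative choices, not part of the package.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

num_tokens, dim = 1000, 512

# Hypothetical stand-in for MambaFormer(..., return_tokens=True).
# It only reproduces the ASSUMED output shape (batch, seq_len, num_tokens).
model = nn.Sequential(nn.Embedding(num_tokens, dim), nn.Linear(dim, num_tokens))
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)


def train_step(tokens: torch.Tensor) -> float:
    """Run one next-token prediction step on a (batch, seq_len) tensor of token ids."""
    model.train()
    logits = model(tokens[:, :-1])  # predictions for every position except the last
    targets = tokens[:, 1:]  # each position's target is the following token
    loss = F.cross_entropy(
        logits.reshape(-1, num_tokens),  # (batch * (seq_len - 1), num_tokens)
        targets.reshape(-1),  # (batch * (seq_len - 1),)
    )
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


x = torch.randint(1, num_tokens, (1, 100))
print(train_step(x))  # cross-entropy loss for this batch
```

To try it with the real model, replace the stand-in with the `MambaFormer(...)` instance from the usage example and keep `num_tokens` consistent. Shifting targets by one position is the standard causal language-modeling setup; it is only meaningful if the model does not attend to future tokens, which this README does not state.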


## License
MIT

