Metadata-Version: 2.1
Name: mm1-torch
Version: 0.0.4
Summary: MM1 - Pytorch
Home-page: https://github.com/kyegomez/mm1
License: MIT
Keywords: artificial intelligence,deep learning,optimizers,Prompt Engineering
Author: Kye Gomez
Author-email: kye@apac.ai
Requires-Python: >=3.6,<4.0
Classifier: Development Status :: 4 - Beta
Classifier: Intended Audience :: Developers
Classifier: License :: OSI Approved :: MIT License
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.6
Classifier: Programming Language :: Python :: 3.7
Classifier: Programming Language :: Python :: 3.8
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Requires-Dist: einops
Requires-Dist: torch
Requires-Dist: zetascale
Project-URL: Documentation, https://github.com/kyegomez/mm1
Project-URL: Repository, https://github.com/kyegomez/mm1
Description-Content-Type: text/markdown

[![Multi-Modality](agorabanner.png)](https://discord.gg/qUtxnK2NMf)

# MM1 
PyTorch Implementation of the paper "MM1: Methods, Analysis & Insights from Multimodal LLM Pre-training".

`img -> encoder -> connector -> llm -> tokens` 
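To make that data flow concrete, here is a minimal, self-contained toy of the same four stages in plain PyTorch. It is not the `mm1_torch` implementation. The class name `TinyVLM`, the single-layer encoder and LLM, the average-pooling connector, prepending image tokens to the text, and every size are assumptions chosen so the example runs quickly.

```python
import torch
from torch import nn


class TinyVLM(nn.Module):
    """Hypothetical toy of img -> encoder -> connector -> llm -> tokens (not mm1_torch code)."""

    def __init__(self, vocab_size=100, dim=64, patch=16, num_img_tokens=16):
        super().__init__()
        # encoder: ViT-style patch embedding (strided conv) + one transformer layer
        self.patch_embed = nn.Conv2d(3, dim, kernel_size=patch, stride=patch)
        self.encoder = nn.TransformerEncoderLayer(dim, nhead=4, batch_first=True)
        # connector: average-pool the patch features down to a fixed number of image tokens
        self.connector = nn.AdaptiveAvgPool1d(num_img_tokens)
        # llm: token embedding + one causally masked transformer layer + vocabulary head
        self.tok_embed = nn.Embedding(vocab_size, dim)
        self.llm = nn.TransformerEncoderLayer(dim, nhead=4, batch_first=True)
        self.to_logits = nn.Linear(dim, vocab_size)

    def forward(self, text_ids, img):
        # (B, 3, H, W) -> (B, dim, H/patch, W/patch) -> (B, num_patches, dim)
        feats = self.encoder(self.patch_embed(img).flatten(2).transpose(1, 2))
        # AdaptiveAvgPool1d pools the last axis, so move the patch axis there and back
        img_tokens = self.connector(feats.transpose(1, 2)).transpose(1, 2)
        # prepend image tokens to the text embeddings: (B, num_img_tokens + T, dim)
        seq = torch.cat([img_tokens, self.tok_embed(text_ids)], dim=1)
        # additive causal mask: position i may only attend to positions <= i
        n = seq.size(1)
        causal = torch.triu(torch.full((n, n), float("-inf")), diagonal=1)
        return self.to_logits(self.llm(seq, src_mask=causal))


model = TinyVLM()
text_ids = torch.randint(0, 100, (1, 32))
img = torch.randn(1, 3, 224, 224)
logits = model(text_ids, img)
print(logits.shape)  # torch.Size([1, 48, 100]): 16 image tokens + 32 text tokens
```

Pooling to a fixed `num_img_tokens` keeps the LLM sequence length independent of image resolution; the price is that finer spatial detail is averaged away.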

## install
`pip3 install mm1-torch`

## usage
```python
import torch
from mm1_torch.main import MM1

# Dummy inputs: 512 text token ids in [0, 100) and one 224x224 RGB image
x = torch.randint(0, 100, (1, 512))
img = torch.randn(1, 3, 224, 224)

# Create a model
model = MM1(
    dim=512,
    depth=12,
    heads=8,
    dim_head=64,
    dropout=0.1,
    num_experts=4,
    num_experts_per_tok=2,
    encoder_dim=512,
    encoder_depth=12,
    encoder_heads=8,
)


# Forward
out = model(x, img)
print(out.shape)
print(out)
```
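The `num_experts` and `num_experts_per_tok` arguments point to a mixture-of-experts (MoE) feed-forward layer with top-k routing; the MM1 paper also reports MoE variants of its models. How `mm1_torch` wires this up internally is not documented here, so the following is only a minimal, hypothetical sketch of top-k routing in plain PyTorch. The `TopKMoE` class, its two-layer expert MLP, and softmax gating over just the selected experts are assumptions.

```python
import torch
import torch.nn.functional as F
from torch import nn


class TopKMoE(nn.Module):
    """Hypothetical top-k mixture-of-experts feed-forward layer (illustration only)."""

    def __init__(self, dim=64, hidden=128, num_experts=4, num_experts_per_tok=2):
        super().__init__()
        self.k = num_experts_per_tok
        # router: one logit per expert for every token
        self.router = nn.Linear(dim, num_experts, bias=False)
        self.experts = nn.ModuleList(
            [
                nn.Sequential(nn.Linear(dim, hidden), nn.GELU(), nn.Linear(hidden, dim))
                for _ in range(num_experts)
            ]
        )

    def forward(self, x):
        b, n, d = x.shape
        flat = x.reshape(-1, d)  # (B*N, dim): every token is routed independently
        # keep the k best experts per token; softmax over just those k logits
        top_logits, top_idx = self.router(flat).topk(self.k, dim=-1)  # (B*N, k)
        weights = F.softmax(top_logits, dim=-1)
        out = torch.zeros_like(flat)
        for e, expert in enumerate(self.experts):
            # which tokens picked expert e, and in which of their k slots
            token_idx, slot = (top_idx == e).nonzero(as_tuple=True)
            if token_idx.numel() == 0:
                continue  # no token was routed to this expert
            gated = weights[token_idx, slot].unsqueeze(-1) * expert(flat[token_idx])
            out.index_add_(0, token_idx, gated)
        return out.reshape(b, n, d)


moe = TopKMoE(num_experts=4, num_experts_per_tok=2)
tokens = torch.randn(1, 10, 64)
y = moe(tokens)
print(y.shape)  # torch.Size([1, 10, 64])
```

Each token runs through only `num_experts_per_tok` experts, which is why an MoE layer can add parameters without a proportional increase in per-token compute.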

### `CAbstractor`

```python
import torch
from mm1_torch.main import CAbstractor

# Dummy input: batch of 1, sequence of 100 tokens, 512 features each
x = torch.randn(1, 100, 512)

# Create a model
model = CAbstractor(
    dim=512,
    depth=12,
    heads=8,
)


# Forward
out = model(x)
print(out.shape)
```
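In the MM1 paper, the C-Abstractor (introduced in Honeybee) is one of the vision-language connectors compared. It is convolutional: it runs conv blocks over the 2D grid of patch features, resamples that grid to a fixed number of visual tokens with adaptive average pooling, and runs more conv blocks; MM1 reports that the number of visual tokens and image resolution matter more than the connector design. The sketch below illustrates that idea only. `ConvAbstractorSketch`, the simplified residual block (no normalization), the square-grid assumption, and the default of 64 output tokens are assumptions, not the `mm1_torch` `CAbstractor`, whose `depth` and `heads` arguments suggest a different internal design.

```python
import math

import torch
from torch import nn


class ResidualConvBlock(nn.Module):
    """Simplified residual 3x3 conv block (stand-in for the ResNet blocks in C-Abstractor)."""

    def __init__(self, dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(dim, dim, kernel_size=3, padding=1),
            nn.SiLU(),
            nn.Conv2d(dim, dim, kernel_size=3, padding=1),
        )

    def forward(self, x):
        return x + self.net(x)


class ConvAbstractorSketch(nn.Module):
    """Hypothetical C-Abstractor-style connector: conv block -> adaptive pool -> conv block."""

    def __init__(self, dim=512, llm_dim=512, num_visual_tokens=64):
        super().__init__()
        side = math.isqrt(num_visual_tokens)
        assert side * side == num_visual_tokens, "num_visual_tokens must be a perfect square"
        self.pre = ResidualConvBlock(dim)
        self.pool = nn.AdaptiveAvgPool2d(side)  # resample any grid to side x side
        self.post = ResidualConvBlock(dim)
        self.proj = nn.Linear(dim, llm_dim)  # map to the LLM hidden size

    def forward(self, x):
        b, n, d = x.shape
        h = math.isqrt(n)
        assert h * h == n, "expects a square grid of patch tokens"
        # (B, N, D) -> (B, D, H, W): restore the 2D layout so convs see local neighborhoods
        grid = x.transpose(1, 2).reshape(b, d, h, h)
        grid = self.post(self.pool(self.pre(grid)))
        # (B, D, s, s) -> (B, s*s, D) -> (B, s*s, llm_dim)
        return self.proj(grid.flatten(2).transpose(1, 2))


connector = ConvAbstractorSketch()
patches = torch.randn(1, 100, 512)  # a 10x10 grid of 512-d patch features
visual_tokens = connector(patches)
print(visual_tokens.shape)  # torch.Size([1, 64, 512])
```

Unlike plain average pooling, the conv blocks let each output token mix information from neighboring patches before and after resampling.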


## License
MIT

