# Python package dependencies, one PEP 508 requirement specifier per line.
# Install with: pip install -r <this file>
# List each project only once: pip's legacy resolver (pip < 20.3) rejects a
# file that names the same project twice ("Double requirement given").
IPython>=8.17.2
# Compatible-release pin (PEP 440): >=0.1.7 and ==0.1.*
chex~=0.1.7
datasets
einops>=0.6.1
flax>=0.7.1
jax>=0.4.10
ml_collections
msgpack>=1.0.5
numpy
optax>=0.1.7
torch>=2.0.0
transformers>=4.34.0
# NOTE: `typing` is intentionally not listed. It is part of the standard
# library since Python 3.5. The PyPI `typing` package is a backport for older
# interpreters (its 3.10.0.0 release declares Python < 3.5), so requiring it
# here would make installation fail on the Python versions the pins above need.
