IPython
chex~=0.1.7
datasets
einops
flax>=0.7.1
jax>=0.4.10
ml_collections
msgpack
numpy
optax>=0.1.7
torch>=2.0.0
transformers>=4.34.0
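# `typing` has been part of the standard library since Python 3.5; the PyPI
# backport can shadow it and break imports on newer interpreters, so it is only
# installed where the standard-library module is missing.
typing; python_version < "3.5"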
