# Third-party Python packages required by this project (pip requirements format).
# One requirement per line; keep each package listed only once.
IPython
chex~=0.1.7
datasets
einops
flax>=0.7.1
jax>=0.4.10
ml_collections
msgpack
optax>=0.1.7
transformers>=4.34.0
# `typing` has shipped with the standard library since Python 3.5; the PyPI
# distribution is only a backport for older interpreters. The environment
# marker keeps the entry but stops pip from installing the backport on
# interpreters that already provide the module.
typing; python_version < "3.5"
