# Pinned dependencies. Most entries use "~=" (PEP 440 compatible release):
# e.g. "chex~=0.1.84" means ">=0.1.84, ==0.1.*".
chex~=0.1.84
datasets~=2.14.7
einops~=0.6.1
flax~=0.7.5
ipython~=8.17.2
# NOTE(review): jax/jaxlib are the only open-ended (">=") pins in this file. A future
# jax release may not be compatible with the flax/chex/optax pins here; consider adding
# an upper bound once a tested-good version is known.
jax>=0.4.16
jaxlib>=0.4.16
msgpack~=1.0.7
numpy~=1.26.2
optax~=0.1.7
setuptools~=68.1.2
tqdm~=4.64.1
# "typing" on PyPI is a backport of the standard-library module for Python < 3.5.
# numpy 1.26 requires Python >= 3.9, so on every interpreter this file can install on,
# the backport would only shadow the stdlib module. The marker prevents installing it there.
typing~=3.7.4.3; python_version < "3.5"
