jax>=0.4.22
jaxlib>=0.4.22
numpy~=1.26.2
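# `typing` is in the standard library since Python 3.5; the PyPI backport is only needed on older interpreters.
typing~=3.7.4.3; python_version < "3.5"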
flax~=0.7.5
chex~=0.1.84
ipython~=8.17.2
datasets~=2.14.7
einops~=0.6.1
msgpack~=1.0.7
tqdm~=4.64.1
optax~=0.1.7
ml_collections==0.1.1
plum-dispatch==2.3.2
termcolor
