# Python package requirements (pip requirements-file format).
# Pin operators used below (PEP 440):
#   `>=X.Y.Z`  lower bound only
#   `~=X.Y.Z`  compatible release: >=X.Y.Z and ==X.Y.*
#   `==X.Y.Z`  exact pin

# Keep the jax and jaxlib floors in sync with each other.
jax>=0.4.20
jaxlib>=0.4.20
# numpy 1.26.x requires Python >= 3.9, which effectively sets the interpreter
# floor for this whole file.
numpy~=1.26.2
# `typing` on PyPI is a backport of the stdlib module (stdlib since Python 3.5).
# Installing it on a modern interpreter is unnecessary, and a site-packages copy
# can conflict with the stdlib module. Restrict it to interpreters that lack the
# stdlib version. Given the numpy floor above, this marker is never true in
# practice, so the backport is not installed.
typing~=3.7.4.3; python_version < "3.5"
flax>=0.7.5
chex>=0.1.7
ipython~=8.17.2
datasets~=2.14.7
einops~=0.6.1
msgpack~=1.0.7
tqdm~=4.64.1
optax~=0.1.7
ml_collections==0.1.1
plum-dispatch==2.3.2
# NOTE(review): unpinned; consider adding at least a lower bound for
# reproducible installs.
termcolor
