jax>=0.4.23
jaxlib>=0.4.23
optax~=0.2.2
msgpack~=1.0.7
ipython~=8.17.2
tqdm~=4.64.1
numpy~=1.26.2
scipy==1.13.1
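# `typing` is a backport for Python < 3.5 only; on newer interpreters it is part of the
# standard library, and installing the backport there can shadow the stdlib module.
typing~=3.7.4.3; python_version < "3.5"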
flax>=0.8.0
chex>=0.1.7
einops==0.8.0
ml-collections==0.1.1
plum-dispatch==2.3.2
termcolor
