# Python package dependencies in pip requirements-file format
# (install with `pip install -r <this file>`).
#
# Version specifiers used below:
#   >=  minimum version only; any newer release is accepted
#   ==  exact pin; only that release is accepted
jax>=0.4.20
ml-collections>=0.1.1
tensorflow>=2.11.0
flax>=0.7.5
# NOTE(review): numpy is pinned exactly while jax/tensorflow/flax above are
# open-ended; the reason for the pin is not recorded here -- confirm it still
# resolves against the versions of those packages that get installed.
numpy==1.24.1
optax>=0.1.7
tqdm>=4.66.1
# NOTE(review): exact pin, reason not recorded here -- TODO confirm it is
# compatible with whichever flax version the `flax>=0.7.5` line selects.
orbax-checkpoint==0.4.2
portpicker>=1.5.2
