numpy
absl-py
ml-collections
wandb
einops
jaxtyping
opt-einsum
transformers
torch
tqdm
jax[cuda12]==0.6.1
optax==0.2.5
flax==0.10.6
chex==0.1.89
datasets>=4.1.1
