einops>=0.8.1
grain
jax-ai-stack[tfds]
ml-collections>=1.1.0

[all]
