numpy>=2.1.0
jax>=0.4.31
flax>=0.9.0
optax>=0.2.3
datasets>=2.21.0
pillow>=10.4.0
ipython>=8.26.0
ipykernel>=6.29.5
ipywidgets>=8.1.5
mediapy>=1.2.2

[cuda]
jax[cuda12]

[dev]
pytest>=8.3.2
pytest-cov>=5.0.0
ruff>=0.6.4
mypy>=1.11.2
pre-commit

[tpu]
jax[tpu]
