numpy==2.1.3
scipy==1.14.1
jax[cpu]==0.4.35
optax==0.2.4
flax==0.10.2
pytest==8.3.4

[doc]
jupyterlab==4.3.0
matplotlib==3.9.2
scikit-learn==1.5.2
sphinx
sphinx-rtd-theme
sphinx-autobuild
myst-parser
sphinx_design
sphinx-autodoc-typehints
nbsphinx
nbsphinx-link
pandoc

[gpu]
jax[cuda12]==0.4.35
