numpy>=2.3.5
scipy>=1.16.3
jax[cpu]>=0.8.1
optax>=0.2.6
flax>=0.12.1

[doc]
jupyterlab
matplotlib
scikit-learn
pytest
sphinx
sphinx-rtd-theme
sphinx-autobuild
myst-parser
sphinx_design
sphinx-autodoc-typehints
nbsphinx
nbsphinx-link
pandoc

[gpu]
jax[cuda12]>=0.8.1
