numpy>=2.1.3
scipy>=1.14.1
jax[cpu]>=0.4.35
optax>=0.2.4
flax>=0.10.2

[doc]
jupyterlab
matplotlib
scikit-learn
pytest
sphinx
sphinx-rtd-theme
sphinx-autobuild
myst-parser
sphinx_design
sphinx-autodoc-typehints
nbsphinx
nbsphinx-link
pandoc

[gpu]
jax[cuda12]
