jax>=0.1.67
jaxlib>=0.1.47
jaxopt>=0.5.5
numpy!=1.25.0,>=1.18.4
flax>=0.5.2
optax>=0.1.1

[:python_version >= "3.9"]
lineax>=0.0.1

[dev]
pre-commit>=2.16.0
tox>=4

[docs]
sphinx>=4.0
sphinx-book-theme>=1.0.1
sphinx_autodoc_typehints>=1.12.0
sphinx-copybutton>=0.5.1
sphinxcontrib-bibtex>=2.5.0
sphinxcontrib-spelling>=7.7.0
myst-nb>=0.17.1

[test]
pytest
pytest-xdist
pytest-cov
pytest-memray
coverage[toml]
chex
networkx>=2.5
scikit-learn>=1.0

[test:python_version < "3.11"]
tslearn>=0.5

[test:python_version >= "3.9"]
lineax
