jax>=0.4.0
jaxlib>=0.4.0
fmmax>=1.0.0
packaging

[test]
pre-commit
pytest-cov
ruff
optax
mypy
