jax>=0.4.25
jaxlib
typing_extensions

[testing]
flax
pytest
pytest-cov
