jax>=0.4.16
jaxlib

[test]
pytest
