jax>=0.4.13
jaxlib>=0.4.13
numpy>=1.20.0

[gpu]
cuda-python
