absl-py>=1.0.0
brax>=0.10.4
chex>=0.1.86
flax>=0.8.5
gym>=0.26.2
jax>=0.4.28
jaxlib>=0.4.28
jinja2>=3.1.4
jumanji>=0.3.1
numpy>=1.26.4
optax>=0.1.9
scikit-learn>=1.5.1
scipy>=1.10.1
tensorflow-probability>=0.24.0

[cuda12]
jax[cuda12]>=0.4.28
