absl-py>=0.9.0
chex>=0.0.5
# Temporary upper bound: change `<=` to `>=` on the jax line below once tensorflow-probability (TFP) is fixed.
jax<=0.2.11
jaxlib>=0.1.37
numpy>=1.18.0
tensorflow-probability>=0.12.1
