jax>=0.4.25
jaxlib
matplotlib
numpy<2
scipy
tensorflow_probability
tqdm
optax
jaxopt

[examples]

[tests]
