jax>=0.6.0
jaxlib
matplotlib
numpy
scipy
tfp_nightly

[examples]
scikit-learn
optax

[experimental]

[tests]
scikit-learn
networkx
psutil
pytest
flake8
