# Base requirements: listed before any [section] header, so they belong to no
# named group.
# jax, optax and chex carry lower bounds only, tqdm is unconstrained, and
# tensorflow-probability and ml-collections are pinned to exact versions.
# NOTE(review): the layout looks like a setuptools-generated requires.txt; if
# so, change the project's packaging configuration rather than this file - confirm.
jax>=0.2.21
optax>=0.1.0
chex>=0.0.8
tensorflow-probability==0.15.0
tqdm
ml-collections==0.1.0

# Development group: code formatters (black, isort) and linters (pylint, flake8).
# No version constraints are given for any of these entries.
[dev]
black
isort
pylint
flake8

# Documentation group: every entry is pinned to an exact version with `==`.
# The furo pin carries a `b24` suffix, i.e. a beta pre-release version.
[docs]
furo==2020.12.30b24
nbsphinx==0.8.1
nb-black==1.0.7
matplotlib==3.3.3
sphinx-copybutton==0.3.5
gpviz==0.0.1

# Test group: the pytest test runner, with no version constraint.
[tests]
pytest
