jax>=0.4.8
numpyro>=0.11.0
dm-haiku>=0.0.5
matplotlib>=3.1
