jax>=0.3.0
jaxlib>=0.3.0
chex
flax
numpy
pyyaml
matplotlib

[:python_version < "3.8"]
pickle5
