numpy>=1.20.0
jax>=0.3.22
optax>=0.1.0
jaxlib>=0.3.22
dm-haiku>=0.0.9

[doc]
sphinx~=4.2.0
myst-parser
furo
