tqdm>=4.60.0
scipy>=1.6.2
numpy>=1.19.2

[dev]
build
twine

[jax]
optax>=0.0.9
jax>=0.2.12
