# Requirements that apply unconditionally: these entries appear before any
# [section] header.
# NOTE(review): the layout (bare version specifiers plus [name] sections) looks
# like a setuptools-generated requires.txt rather than TOML - confirm before
# hand-editing, since such a file would normally be regenerated from the
# package metadata.
numpy>=1.23.3
scipy>=1.9.1

# Requirements grouped under the "jax" section.
# Presumably an optional extra (selected with `pip install <package>[jax]`) -
# verify against the packaging configuration.
[jax]
# jax and jaxlib are given the same minimum version (0.4.1).
jax>=0.4.1
jaxlib>=0.4.1
optax>=0.1.4
# tqdm is also listed, with the same lower bound, in the [pytorch] section.
tqdm>=4.64.1

# Requirements grouped under the "pytorch" section.
# Presumably an optional extra (selected with `pip install <package>[pytorch]`) -
# verify against the packaging configuration.
[pytorch]
torch>=2.0.0
# tqdm is also listed, with the same lower bound, in the [jax] section.
tqdm>=4.64.1
