# Third-party dependencies in pip requirements format (PEP 508 / PEP 440 specifiers).
# Every entry has an inclusive lower bound. jax's upper bound is inclusive;
# numpy's and scipy's upper bounds are exclusive.
# NOTE(review): the upper bounds presumably mark the newest versions known to
# work with this project — confirm before widening them.

# jax with the "cpu" extra.
jax[cpu]>=0.3.2,<=0.4.14
numpy>=1.20.0,<1.25.0
scipy>=1.5.0,<1.11.0
