numpy>=1.17
scipy
jax<0.3.14,>=0.1.65
jaxlib<0.3.14,>=0.1.65
jaxopt>=0.2
