jax==0.4.28
jaxtyping
jax-tqdm
optax
jaxopt
numpy
scipy
matplotlib
pytest

[dev]
