jax
jaxlib>=0.3
equinox
tqdm
optax

[dev]
pytest
