jax
jaxlib>=0.3
equinox
tqdm
optax
jaxtyping

[dev]
pytest
