jax
jaxlib>=0.3
equinox>=0.10
tqdm
optax

[dev]
pytest
