jax>=0.4.20
jaxlib>=0.4.20
optax>=0.2.0
flax>=0.8.0

[dev]
pytest>=7.4
