chex>=0.1.86
dm-haiku>=0.0.5
fairscale>=0.4.6
flax==0.8.5
jax==0.4.30
jaxlib==0.4.30
numpy>=1.26.4
optax==0.2.3
scipy>=1.10.1
transformers>=4.30.2

[dev]
flake8==3.9.2
pytest==6.2.4
