jax==0.4.38
jaxlib==0.4.38
numpy~=1.26.4
flax<=0.10.6
joblib~=1.3.2

[dev]
pytest~=8.0.0
pytest-datadir~=1.5.0
coverage~=7.4.3

[huggingface]
huggingface-hub~=0.21.3
