flax
immutabledict==4.2.0
jax[cuda]
jaxtyping
numpy==1.26.0
optax
scikit-learn
torch==2.4.0
tqdm
transformers==4.43.3

[notebook]
datasets
huggingface_hub
ipywidgets
pandas
tensorflow

[notebook-local]
synthid-text[notebook]
notebook

[test]
absl-py
mock
pytest
