flax~=0.7.0
jax~=0.4.13
optax~=0.1.7
orbax-checkpoint~=0.2.7
numpy~=1.24.2
scikit-learn~=1.2.2
Pillow~=9.4.0
torch~=2.0.0
torchvision~=0.15.1
simple_parsing~=0.1.3
tqdm
loguru
matplotlib

[dev]
pytest
pre-commit

[plots]
iceberg-dsl==0.0.6

[wandb]
wandb
