jax>=0.4.0
jaxlib>=0.4.0
flax>=0.7.0
numpy>=1.20.0
ml_collections>=0.1.0

[all]
torch>=1.10.0

[dev]
pytest>=7.0.0
black>=22.0.0
isort>=5.10.0
flake8>=4.0.0

[torch]
torch>=1.10.0
