absl-py
flax
jax>=0.1.76
jaxlib
ml_collections
numpy>=1.16.4
tensorflow>=2.3.0
tensorflow_datasets

[:python_version < "3.7"]
dataclasses

[pytorch]
torch>=1.2.0

[test]
dm-sonnet
pytest
torch>=1.2.0
