tensorflow>=2.4.0
tensorflow-datasets>=4.2.0
dm-haiku>=0.0.3
optax>=0.0.6
