# Pinned Python dependencies (pip requirements format).
# NOTE(review): every package is pinned to an exact version; these pins are
# presumably chosen to be mutually compatible — confirm before bumping any one
# of them in isolation.
tensorflow==2.9.3
tensorflow-datasets==4.6.0
# jax and jaxlib are separate distributions pinned independently here; the
# installed jaxlib must satisfy the minimum jaxlib version this jax release
# checks at import time, so upgrade/re-verify the two together.
jax==0.3.13
# Extra wheel index for the GPU jaxlib build below. Wheels carrying a local
# version label ("+...") cannot be uploaded to PyPI, so pip must be pointed at
# this page to find them. Note that pip treats --find-links as a global option
# for the whole requirements file, not just for the next line.
--find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# The local version label "+cuda11.cudnn805" selects a specific GPU build of
# jaxlib; presumably CUDA 11 with cuDNN 8.0.5 — TODO confirm it matches the
# CUDA/cuDNN installed on the target machine.
jaxlib==0.3.10+cuda11.cudnn805