jax==0.3.2
jaxlib==0.3.2
optax>=0.1.0
chex==0.1.3
distrax>=0.1.2
tensorflow==2.8.1
tensorflow-probability==0.16.0
tqdm>=4.0.0
ml-collections==0.1.0
protobuf==3.19.0