jax
jaxlib
tfp-nightly
numpy
scipy
optax
jaxtyping>=0.2.15
tqdm
datasets
