einops>=0.6
scipy
torch>=1.13
jax
