torch>=2.0
numpy
einops
opt_einsum

[flex-attention]
torch>=2.7

[xformers-attention]
xformers
