torch>=2.0
numpy
einops
opt_einsum

[flex_attention]
torch>=2.7

[xformers_attention]
xformers
