torch>=1.9.0
torchinfo
numpy
scikit-learn
timm==1.0.3
triton
