torch<2.4,>2
jaxtyping
numpy
graphviz
