carabiner-tools[mpl,pd]>=0.0.4
numpy
tqdm

[jax]
jax>=0.5.3

[jax_cuda12]
jax[cuda12]>=0.5.3

[jax_cuda12_local]
jax[cuda12_local]>=0.5.3

[torch]
torch>=2.4
torchvision
lightning
tensorboard
