numpy
matplotlib
jax
diffrax