# Pinned dependencies in pip requirements format: every entry below uses `==`,
# so an install resolves each listed package to exactly that version.
# NOTE(review): packages pulled in transitively are not pinned by the entries shown here.
# Plotting library.
matplotlib==3.9.0
# JAX and its compiled backend; both are pinned to the same release (0.4.28).
# NOTE(review): bump jax and jaxlib together when upgrading; confirm the supported pairing in the JAX install docs.
jax==0.4.28
jaxlib==0.4.28
# Testing/assertion utilities for JAX code.
chex==0.1.86
# General numerical stack.
numpy==1.26.4
scipy==1.13.1
# Test runner and its coverage plugin.
# NOTE(review): presumably development-only; consider a separate dev requirements file if one exists — verify.
pytest==8.2.1
pytest-cov==5.0.0
# JAX-based libraries: differential-equation solvers (diffrax) and dataclass support (jax-dataclasses).
diffrax==0.5.1
jax-dataclasses==1.6.0
