# Python package requirements in pip requirements-file format, one specifier per line.
# Entries using ">=" declare a minimum accepted version; no upper bounds are set here.
distrax>=0.1.5
flax>=0.8.1
jax>=0.4.25
optax>=0.1.7
ott-jax>=0.4.9
# The entries below carry no version constraint, so any available release satisfies them.
numpy
# NOTE(review): pytest is presumably needed only to run the test suite — confirm whether
# it belongs in a separate development/test requirements file instead.
pytest
