numpy
jax>=0.4.29
einops
flax
