einops>=0.3
jax>=0.2.10
flax>=0.3.2
