
jax>=0.3.5
pytreeclass>=0.0.9rc1
