# Python package requirements, one requirement specifier per line.
# Each entry sets only a lower bound (>=); no upper bound is pinned.
jax>=0.3.5
# NOTE(review): 0.0.9rc1 is a release-candidate (pre-release) version. Naming a
# pre-release in the specifier presumably allows installers to select
# pre-releases of pytreeclass -- confirm this is intended.
pytreeclass>=0.0.9rc1
