# Python package dependencies, in pip requirements-file format.
jax>=0.4.0
pytreeclass>=0.3.0
# NOTE(review): kernex has no version constraint, unlike the entries above;
# consider pinning a minimum tested version so installs are reproducible.
kernex
