scikit-learn>=1.0.2
pandas
jax[cpu]
tqdm
optax
pydantic<2,>=1.9.0
jax-tqdm
einops
keras<3.4.0,>=3.0.3
matplotlib
seaborn
dm-haiku

[dev]
torch>=1.7.0
causalgraphicalmodels
pre-commit
datasets
nbdev>=v2.3.28
jupyter
flax>=0.7.1
