
[all]
jax>=0.4.14

[all:platform_machine != "arm64" or platform_system != "Darwin"]
jaxlib>=0.4.14

[all:platform_system != "Darwin" or platform_machine != "arm64"]
torch>=1.12

[dev]
numpy>=1.23
pytest>=7.4
pytest-cov>=4.1
hypothesis>=6.88
black>=23.7
ruff>=0.4.0
isort>=5.12
mypy>=1.4
pre-commit>=3.3

[jax]
jax>=0.4.14

[jax:platform_machine != "arm64" or platform_system != "Darwin"]
jaxlib>=0.4.14

[torch]

[torch:platform_system != "Darwin" or platform_machine != "arm64"]
torch>=1.12
