# Simple Math SFT Example Dependencies
#
# SpotJAX automatically adds the JAX TPU releases index URL when installing these packages.
# For local development, run: pip install -r requirements.txt

# JAX with TPU support - bonsai requires JAX 0.8+
# SpotJAX adds the TPU releases index automatically
jax[tpu]>=0.8.0

# Flax NNX (modern neural network API)
flax>=0.8.0

# Bonsai - JAX model implementations
# NOTE: bonsai has broken package metadata, so it is not listed here; it is installed via spotax_setup.sh instead.

# Optax - JAX optimizer library
optax>=0.2.0

# Orbax - Checkpointing
orbax-checkpoint>=0.5.0

# Grain - Data loading
grain>=0.2.0

# HuggingFace Transformers (for tokenizer)
transformers>=4.40.0

# HuggingFace Datasets (for GSM8K)
datasets>=2.18.0

# Utilities
numpy>=1.24.0
gcsfs>=2024.2.0  # For GCS checkpoint access
