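# ImageNet EfficientNet Example Dependencies
#
# When SpotJAX installs this file, it automatically adds the JAX TPU releases
# URL to the pip invocation, so no extra index/find-links flags are needed here.
#
# For local development: pip install -r requirements.txt
#   - If libtpu fails to resolve, add the TPU releases URL yourself, e.g.
#     -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
#   - Bonsai is not listed below; install it separately (see spotax_setup.sh).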

# --- Core JAX stack ---------------------------------------------------------

# JAX runtime with TPU support
jax[tpu]

# Flax (provides the NNX neural-network API)
flax

# Optax: optimizer library for JAX
optax

# Orbax: checkpointing
orbax-checkpoint

# --- Input pipeline ---------------------------------------------------------

# Grain data loading, plus array-record for ArrayRecord support
grain
array-record

# msgpack: serialization used by the ArrayRecord pipeline
msgpack

# Pillow: image decoding and preprocessing
pillow

# --- Models & pretrained weights --------------------------------------------

# Bonsai (JAX model implementations) is intentionally NOT listed here:
# its package metadata is broken, so spotax_setup.sh installs it instead.

# timm: loads the pretrained EfficientNet weights.
# NOTE(review): timm depends on PyTorch, so installing it also pulls in torch;
# confirm that is acceptable on TPU hosts.
timm>=1.0.0

# --- Utilities --------------------------------------------------------------

numpy>=1.24.0

# gcsfs: access to checkpoints stored on Google Cloud Storage (GCS)
gcsfs
