Metadata-Version: 2.4
Name: jaxboost
Version: 0.1.0
Summary: Next-generation differentiable gradient boosting with JAX
Project-URL: Homepage, https://github.com/jxu/jaxboost
Project-URL: Repository, https://github.com/jxu/jaxboost
Project-URL: Issues, https://github.com/jxu/jaxboost/issues
Author-email: J Xu <jxucoder@gmail.com>
License-Expression: MIT
License-File: LICENSE
Keywords: differentiable-programming,gpu,gradient-boosting,jax,machine-learning,soft-trees
Classifier: Development Status :: 3 - Alpha
Classifier: Intended Audience :: Developers
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: MIT License
Classifier: Operating System :: OS Independent
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Requires-Python: >=3.10
Requires-Dist: chex>=0.1.8
Requires-Dist: jax>=0.4.20
Requires-Dist: jaxlib>=0.4.20
Requires-Dist: optax>=0.1.7
Provides-Extra: benchmark
Requires-Dist: scikit-learn>=1.3.0; extra == 'benchmark'
Requires-Dist: xgboost>=2.0.0; extra == 'benchmark'
Provides-Extra: dev
Requires-Dist: mypy>=1.6.0; extra == 'dev'
Requires-Dist: pytest>=7.4.0; extra == 'dev'
Requires-Dist: ruff>=0.1.0; extra == 'dev'
Provides-Extra: docs
Requires-Dist: mkdocs-material>=9.5.0; extra == 'docs'
Requires-Dist: mkdocs>=1.5.0; extra == 'docs'
Requires-Dist: mkdocstrings[python]>=0.24.0; extra == 'docs'
Provides-Extra: polars
Requires-Dist: polars>=0.19.0; extra == 'polars'
Requires-Dist: pyarrow>=14.0.0; extra == 'polars'
Description-Content-Type: text/markdown

# jaxboost

Differentiable gradient boosting in JAX.

⚠️ **This is a personal learning project and very much a work in progress.** It is not intended to replace production boosting libraries such as XGBoost, LightGBM, or CatBoost; the main purpose is to learn JAX while rethinking gradient boosting from first principles. Reliability and correctness are not guaranteed yet. Issues are welcome!

## What it is

A gradient boosting implementation built from soft (differentiable) oblivious trees. The traditional approach greedily grows one tree at a time to fit the pseudo-residuals of the trees before it. Here, the entire model is instead trained end-to-end with gradient descent via optax.

Key characteristics:
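- Soft routing with sigmoid functions (each sample reaches every leaf with some probability, so the trees are differentiable)
- Oblivious tree structure (every node at a given depth shares the same split)
- Hyperplane splits (each split thresholds a linear combination of features rather than a single feature)
- Runs on GPU via JAX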
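To make the three structural ideas concrete, here is a minimal JAX sketch of a single soft oblivious tree with hyperplane splits. It is not jaxboost's code: the parameter layout (`w`, `b`, `leaves`), the helper names, and the `temperature` argument are assumptions for illustration. Each level owns one hyperplane, a sigmoid turns the sample's signed distance to it into a probability of branching right, and a leaf's weight is the product of the branch probabilities along its path.

```python
import jax
import jax.numpy as jnp


def init_soft_oblivious_tree(key, n_features, depth):
    """Hypothetical parameter layout for one soft oblivious tree."""
    k_w, k_leaf = jax.random.split(key)
    return {
        "w": 0.1 * jax.random.normal(k_w, (depth, n_features)),  # one hyperplane per level
        "b": jnp.zeros(depth),                                     # one bias per level
        "leaves": 0.1 * jax.random.normal(k_leaf, (2**depth,)),    # one value per leaf
    }


def soft_oblivious_tree(params, x, temperature=1.0):
    """Soft prediction for x of shape (n_samples, n_features) -> (n_samples,)."""
    depth = params["w"].shape[0]
    # Probability of branching right at each level, shared by all nodes at that depth.
    p_right = jax.nn.sigmoid((x @ params["w"].T + params["b"]) / temperature)  # (n, depth)
    # Bit l of a leaf's index says whether its path goes right at level l.
    bits = (jnp.arange(2**depth)[:, None] >> jnp.arange(depth)) & 1        # (leaves, depth)
    branch_p = jnp.where(bits == 1, p_right[:, None, :], 1.0 - p_right[:, None, :])
    leaf_p = jnp.prod(branch_p, axis=-1)  # (n, leaves); each row sums to 1
    return leaf_p @ params["leaves"]      # probability-weighted average of leaf values


params = init_soft_oblivious_tree(jax.random.PRNGKey(0), n_features=5, depth=3)
x = jax.random.normal(jax.random.PRNGKey(1), (8, 5))
print(soft_oblivious_tree(params, x).shape)  # (8,)

# Every parameter receives a gradient, which is what makes end-to-end training possible.
grads = jax.grad(lambda p: soft_oblivious_tree(p, x).sum())(params)
```

Lowering `temperature` sharpens the sigmoids toward hard 0/1 routing; in the limit the sketch behaves like an ordinary oblivious tree with oblique splits.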

## Installation

```bash
pip install jaxboost
```

Or from source:

```bash
git clone https://github.com/jxu/jaxboost.git
cd jaxboost
pip install -e .
```

## Usage

```python
from jaxboost import GBMTrainer, TrainerConfig

# X_train, y_train, X_test: your own feature and target arrays

# Regression
trainer = GBMTrainer(task="regression")
model = trainer.fit(X_train, y_train)
predictions = model.predict(X_test)

# Classification
trainer = GBMTrainer(task="classification")
model = trainer.fit(X_train, y_train)
probabilities = model.predict(X_test)
classes = model.predict_class(X_test)
```

### Configuration

```python
from jaxboost import GBMTrainer, TrainerConfig

config = TrainerConfig(
    n_trees=20,          # Number of trees
    depth=4,             # Tree depth
    learning_rate=0.01,  # Optimizer learning rate
    epochs=500,          # Maximum number of training epochs
    patience=50,         # Early stopping patience (epochs without improvement)
    verbose=True,        # Print progress
)
trainer = GBMTrainer(task="regression", config=config)
```

## Requirements

- Python >= 3.10
- JAX >= 0.4.20 and jaxlib >= 0.4.20
- optax >= 0.1.7
- chex >= 0.1.8

## License

MIT
