Metadata-Version: 2.4
Name: jaxboost
Version: 0.4.0
Summary: Next-generation differentiable gradient boosting with JAX
Project-URL: Homepage, https://github.com/jxucoder/jaxboost
Project-URL: Documentation, https://jxucoder.github.io/jaxboost/
Project-URL: Repository, https://github.com/jxucoder/jaxboost
Project-URL: Issues, https://github.com/jxucoder/jaxboost/issues
Author-email: J Xu <jxucoder@gmail.com>
License-Expression: Apache-2.0
License-File: LICENSE
Keywords: differentiable-programming,gpu,gradient-boosting,jax,machine-learning,soft-trees
Classifier: Development Status :: 3 - Alpha
Classifier: Intended Audience :: Developers
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: Apache Software License
Classifier: Operating System :: OS Independent
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Requires-Python: >=3.10
Requires-Dist: chex>=0.1.8
Requires-Dist: jax>=0.4.20
Requires-Dist: jaxlib>=0.4.20
Requires-Dist: optax>=0.1.7
Provides-Extra: benchmark
Requires-Dist: scikit-learn>=1.3.0; extra == 'benchmark'
Requires-Dist: xgboost>=2.0.0; extra == 'benchmark'
Provides-Extra: dev
Requires-Dist: mypy>=1.6.0; extra == 'dev'
Requires-Dist: pytest>=7.4.0; extra == 'dev'
Requires-Dist: ruff>=0.1.0; extra == 'dev'
Provides-Extra: docs
Requires-Dist: mkdocs-material>=9.5.0; extra == 'docs'
Requires-Dist: mkdocs>=1.5.0; extra == 'docs'
Requires-Dist: mkdocstrings[python]>=0.24.0; extra == 'docs'
Provides-Extra: macos
Requires-Dist: jax-metal>=0.1.1; extra == 'macos'
Provides-Extra: ode
Requires-Dist: diffrax>=0.5.0; extra == 'ode'
Provides-Extra: polars
Requires-Dist: polars>=0.19.0; extra == 'polars'
Requires-Dist: pyarrow>=14.0.0; extra == 'polars'
Description-Content-Type: text/markdown

# JAXBoost

[![Tests](https://github.com/jxucoder/jaxboost/actions/workflows/tests.yml/badge.svg)](https://github.com/jxucoder/jaxboost/actions/workflows/tests.yml)
[![Lint](https://github.com/jxucoder/jaxboost/actions/workflows/lint.yml/badge.svg)](https://github.com/jxucoder/jaxboost/actions/workflows/lint.yml)
[![Python 3.10+](https://img.shields.io/badge/python-3.10+-blue.svg)](https://www.python.org/downloads/)
[![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://www.apache.org/licenses/LICENSE-2.0)

**JAX autodiff for XGBoost/LightGBM objectives.**

Write a loss function, get gradients and Hessians automatically. No manual derivation needed.

Works with **XGBoost** and **LightGBM**.

## Install

```bash
pip install jaxboost
```
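
Optional extras are declared in the package metadata and can be pulled in with pip's extras syntax:

```bash
# Extras declared by the package
pip install "jaxboost[benchmark]"   # scikit-learn + xgboost for running benchmarks
pip install "jaxboost[polars]"      # polars + pyarrow dataframe support
pip install "jaxboost[macos]"       # jax-metal backend for Apple GPUs
```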

## Quick Start

### XGBoost

```python
import xgboost as xgb
import jax.numpy as jnp
from jaxboost import auto_objective, focal_loss, huber, quantile

# Prepare your data
dtrain = xgb.DMatrix(X_train, label=y_train)
params = {"max_depth": 4, "eta": 0.1}

# Built-in objectives - just use them
model = xgb.train(params, dtrain, num_boost_round=100, obj=focal_loss.xgb_objective)
model = xgb.train(params, dtrain, num_boost_round=100, obj=huber.xgb_objective)
model = xgb.train(params, dtrain, num_boost_round=100, obj=quantile(0.9).xgb_objective)

# Custom objective - write the loss, autodiff handles the rest
@auto_objective
def asymmetric_mse(y_pred, y_true, alpha=0.7):
    error = y_true - y_pred
    return jnp.where(error > 0, alpha * error**2, (1 - alpha) * error**2)

model = xgb.train(params, dtrain, num_boost_round=100, obj=asymmetric_mse.xgb_objective)
```

### LightGBM

```python
import lightgbm as lgb
from jaxboost import huber

train_data = lgb.Dataset(X_train, label=y_train)
params = {"max_depth": 4, "learning_rate": 0.1}

model = lgb.train(params, train_data, num_boost_round=100, fobj=huber.lgb_objective)
```
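
Note that LightGBM 4.0 removed the `fobj` argument from `lgb.train`; on newer LightGBM versions the callable goes into `params` instead. A minimal sketch, assuming `huber.lgb_objective` matches the custom-objective signature LightGBM expects:

```python
# LightGBM >= 4.0: pass the custom objective via params rather than fobj
params = {"max_depth": 4, "learning_rate": 0.1, "objective": huber.lgb_objective}
model = lgb.train(params, train_data, num_boost_round=100)
```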


## Available Objectives

### Regression
| Objective | Description |
|-----------|-------------|
| `mse` | Mean squared error |
| `huber` | Huber loss (robust to outliers) |
| `pseudo_huber` | Smooth approximation of Huber loss |
| `log_cosh` | Log-cosh loss |
| `mae_smooth` | Smooth approximation of MAE |
| `quantile(q)` | Quantile regression |
| `asymmetric(alpha)` | Asymmetric squared error |
| `poisson` | Poisson deviance (count data) |
| `gamma` | Gamma deviance (positive continuous) |
| `tweedie(p)` | Tweedie deviance |
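
Parameterized objectives such as `quantile(q)`, `asymmetric(alpha)`, and `tweedie(p)` are instantiated first and then attached like any other objective; a sketch assuming they expose the same `xgb_objective` adapter as `quantile` in the Quick Start:

```python
import xgboost as xgb
from jaxboost import tweedie, asymmetric

# Instantiate with the shape parameter, then pass the adapter to xgb.train
model = xgb.train(params, dtrain, num_boost_round=100, obj=tweedie(1.5).xgb_objective)
model = xgb.train(params, dtrain, num_boost_round=100, obj=asymmetric(0.7).xgb_objective)
```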

### Binary Classification
| Objective | Description |
|-----------|-------------|
| `focal_loss` | Focal loss for imbalanced data |
| `binary_crossentropy` | Standard log loss |
| `weighted_binary_crossentropy` | Weighted binary cross-entropy |
| `hinge_loss` | SVM-style hinge loss |

### Multi-class Classification
| Objective | Description |
|-----------|-------------|
| `softmax_cross_entropy` | Standard multi-class |
| `focal_multiclass` | Focal loss for multi-class |
| `label_smoothing(eps)` | Label smoothing regularization |
| `class_balanced` | Class-balanced loss |

### Survival Analysis
| Objective | Description |
|-----------|-------------|
| `aft` | Accelerated failure time (log-normal) |
| `weibull_aft` | Weibull AFT model |

### Ordinal Regression
| Objective | Description |
|-----------|-------------|
| `ordinal_logit` | Cumulative Link Model (logit link) |
| `ordinal_probit` | Cumulative Link Model (probit link) |
| `qwk_ordinal` | QWK-aligned Expected Quadratic Error |
| `squared_cdf_ordinal` | CRPS / Ranked Probability Score |
| `hybrid_ordinal` | NLL + EQE hybrid |
| `sord_objective` | SORD (Soft Ordinal) from SLACE paper |
| `oll_objective` | OLL (Ordinal Log-Loss) from SLACE paper |
| `slace_objective` | SLACE (AAAI 2025) |

### Multi-task Learning
| Objective | Description |
|-----------|-------------|
| `multi_task_regression` | Multiple regression targets |
| `multi_task_classification` | Multiple classification targets |
| `multi_task_huber` | Multi-task Huber loss |
| `multi_task_quantile` | Multi-task quantile loss |
| `MaskedMultiTaskObjective` | Handle missing labels |

### Uncertainty Estimation
| Objective | Description |
|-----------|-------------|
| `gaussian_nll` | Predict mean + variance |
| `laplace_nll` | Predict median + scale |

## Ordinal Regression

XGBoost/LightGBM have no native ordinal objective. JAXBoost implements proper [Cumulative Link Models](https://en.wikipedia.org/wiki/Ordered_logit):

```python
from jaxboost import ordinal_logit, qwk_ordinal

# Wine quality: 6 ordered classes (3-8 mapped to 0-5)
ordinal = ordinal_logit(n_classes=6)
ordinal.init_thresholds_from_data(y_train)

# Works with XGBoost
model = xgb.train(params, dtrain, obj=ordinal.xgb_objective)

# Or LightGBM
model = lgb.train(params, train_data, fobj=ordinal.lgb_objective)

# Get class probabilities
probs = ordinal.predict_proba(model.predict(dtest))
classes = ordinal.predict(model.predict(dtest))
```

## Evaluation Metrics

When using custom objectives, use matching evaluation metrics:

```python
from jaxboost import ordinal_logit
from jaxboost.metric import qwk_metric, mae_metric

ordinal = ordinal_logit(n_classes=6)
ordinal.init_thresholds_from_data(y_train)

# Train with custom metric monitoring
model = xgb.train(
    {'disable_default_eval_metric': 1, 'max_depth': 4},  # Disable default metrics!
    dtrain,
    obj=ordinal.xgb_objective,
    custom_metric=ordinal.qwk_metric.xgb_metric,  # Built-in QWK metric
    evals=[(dtest, 'test')]
)
```

### Available Metrics
| Category | Metrics |
|----------|---------|
| **Ordinal** | `qwk_metric`, `ordinal_mae_metric`, `ordinal_accuracy_metric`, `adjacent_accuracy_metric` |
| **Classification** | `auc_metric`, `f1_metric`, `accuracy_metric`, `precision_metric`, `recall_metric` |
| **Regression** | `mse_metric`, `rmse_metric`, `mae_metric`, `r2_metric` |
| **Bounded** | `bounded_mse_metric`, `out_of_bounds_metric` |
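
The standalone metrics pair with any objective in the same way as the ordinal example above; a sketch assuming they expose the same `.xgb_metric` adapter:

```python
import xgboost as xgb
from jaxboost import huber
from jaxboost.metric import mae_metric

# Train with the Huber objective while monitoring MAE on a holdout set.
# Assumes mae_metric exposes an .xgb_metric adapter like qwk_metric above.
model = xgb.train(
    {"disable_default_eval_metric": 1, "max_depth": 4},
    dtrain,
    obj=huber.xgb_objective,
    custom_metric=mae_metric.xgb_metric,
    evals=[(dtest, "test")],
)
```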

## Custom Objectives

The `@auto_objective` decorator turns any loss function into an XGBoost/LightGBM objective:

```python
import xgboost as xgb
import lightgbm as lgb
import jax.numpy as jnp
from jaxboost import auto_objective

@auto_objective
def my_custom_loss(y_pred, y_true, **kwargs):
    # Write your loss here - JAX computes grad/hess automatically
    return (y_pred - y_true) ** 2

# Use with XGBoost
dtrain = xgb.DMatrix(X_train, label=y_train)
params = {"max_depth": 4, "eta": 0.1}
model = xgb.train(params, dtrain, num_boost_round=100, obj=my_custom_loss.xgb_objective)

# Use with LightGBM
train_data = lgb.Dataset(X_train, label=y_train)
params = {"max_depth": 4, "learning_rate": 0.1}
model = lgb.train(params, train_data, num_boost_round=100, fobj=my_custom_loss.lgb_objective)

# Pass parameters
model = xgb.train(
    params, dtrain, num_boost_round=100,
    obj=my_custom_loss.get_xgb_objective(alpha=0.5)
)
```

## Multi-class Example

```python
import xgboost as xgb
import jax
import jax.numpy as jnp
from jaxboost import multiclass_objective

@multiclass_objective(num_classes=3)
def custom_multiclass(logits, label):
    # logits: (num_classes,), label: scalar
    probs = jax.nn.softmax(logits)
    return -jnp.log(probs[label] + 1e-7)

dtrain = xgb.DMatrix(X_train, label=y_train)
model = xgb.train(
    {"num_class": 3, "max_depth": 4, "eta": 0.1},
    dtrain,
    num_boost_round=100,
    obj=custom_multiclass.xgb_objective
)
```

## Sklearn Interface

Use custom objectives with `XGBClassifier` and `XGBRegressor`:

```python
from xgboost import XGBClassifier
from jaxboost import focal_loss

clf = XGBClassifier(
    objective=focal_loss.sklearn_objective,
    n_estimators=100,
    max_depth=4
)
clf.fit(X_train, y_train)
y_pred = clf.predict(X_test)
```
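
`XGBRegressor` works the same way; a sketch assuming regression objectives expose the same `sklearn_objective` adapter:

```python
from xgboost import XGBRegressor
from jaxboost import huber

# Robust regression with the Huber objective through the sklearn interface
reg = XGBRegressor(
    objective=huber.sklearn_objective,
    n_estimators=100,
    max_depth=4
)
reg.fit(X_train, y_train)
y_pred = reg.predict(X_test)
```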

## Why jaxboost?

| Traditional Approach | jaxboost |
|---------------------|----------|
| Derive gradients by hand | Write the loss, get gradients for free |
| Derive Hessians by hand | Write the loss, get Hessians for free |
| Error-prone math | JAX autodiff is correct by construction |
| One loss = hours of work | One loss = 5 lines of code |

## Benchmark Results

JAXBoost shines when XGBoost/LightGBM have **no native solution**:

### Bounded Regression (Proportions in [0, 1])

Predicting proportions, where a standard MSE objective can produce predictions outside the valid [0, 1] range.

| Model | MSE | Out-of-Bounds | Code |
|-------|-----|---------------|------|
| **JAXBoost Soft CE** | **0.0181** | 0% | 5 lines |
| Native MSE + Clip | 0.0201 | 0% | post-hoc fix |
| Native MSE | 0.0201 | 4.9% | - |

**9.5% improvement** + guaranteed valid outputs.

```python
import jax.numpy as jnp
from jax.nn import sigmoid
from jaxboost import auto_objective

@auto_objective
def soft_crossentropy(y_pred, y_true):
    # Squash raw predictions into (0, 1), then score with binary cross-entropy
    mu = sigmoid(y_pred)
    return -(y_true * jnp.log(mu) + (1 - y_true) * jnp.log(1 - mu))
```

### Ordinal Regression (Wine Quality)

Predicting ordered categories (ratings 3-8) with Quadratic Weighted Kappa.

| Model | QWK | Probabilistic |
|-------|-----|---------------|
| Regression + OptRounder | 0.55 | No |
| **JAXBoost Squared CDF** | **0.54** | **Yes** |
| Native Multi-class | 0.51 | Yes |
| Native Regression | 0.48 | No |

JAXBoost ordinal objectives provide **proper probability distributions** over classes.

### When to Use JAXBoost

| Problem | XGBoost/LightGBM Native? | JAXBoost Advantage |
|---------|--------------------------|-------------------|
| Bounded regression [0,1] | ❌ No | ✅ 9.5% better MSE |
| Ordinal regression | ❌ No | ✅ Probabilistic outputs |
| Multi-task + missing labels | ❌ No | ✅ Proper masking |
| Custom business metrics | ❌ No | ✅ 5 lines of code |

📊 [Full benchmark details →](https://jxucoder.github.io/jaxboost/benchmarks/)

## Requirements

- Python >= 3.10
- JAX >= 0.4.20

## Documentation

Full documentation available at: https://jxucoder.github.io/jaxboost/

## License

Apache 2.0
