Metadata-Version: 2.1
Name: boax
Version: 0.0.2
Summary: Boax is a Bayesian Optimization library for JAX.
Project-URL: homepage, https://github.com/Lando-L/boax
Classifier: Development Status :: 3 - Alpha
Classifier: Intended Audience :: Developers
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: Apache Software License
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Requires-Python: >=3.10
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: numpy>=1.26.2
Requires-Dist: jax>=0.4.20
Requires-Dist: jaxlib>=0.4.20
Requires-Dist: chex>=0.1.84
Provides-Extra: docs
Requires-Dist: ipython>=8.18.1; extra == "docs"
Requires-Dist: matplotlib>=3.8.2; extra == "docs"
Requires-Dist: myst-nb>=1.0.0; extra == "docs"
Requires-Dist: optax>=0.1.7; extra == "docs"
Requires-Dist: sphinx>=6.2.1; extra == "docs"
Requires-Dist: sphinx-autodoc-typehints>=1.23.0; extra == "docs"
Requires-Dist: sphinx-book-theme>=1.0.1; extra == "docs"
Provides-Extra: examples
Requires-Dist: matplotlib>=3.8.2; extra == "examples"
Requires-Dist: optax>=0.1.7; extra == "examples"
Provides-Extra: tests
Requires-Dist: absl-py>=2.0.0; extra == "tests"
Requires-Dist: pytest>=7.4.3; extra == "tests"
Requires-Dist: pytest-cov>=4.1.0; extra == "tests"

# Boax: A Bayesian Optimization library for JAX.

[**Overview**](#overview)
| [**Installation**](#installation)
| [**Getting Started**](#getting-started)
| [**Documentation**](https://boax.readthedocs.io/en/latest/)

*Boax is currently in early alpha and under active development!*

## Overview

Boax is a composable library of core components for Bayesian Optimization
that is **designed for flexibility**. It comes with low-level interfaces for:

* **Fitting a surrogate model to data** (`boax.prediction`):
  * Kernels
  * Mean Functions
  * Surrogate Models
* **Constructing and optimizing acquisition functions** (`boax.optimization`):
  * Acquisition Functions
  * Maximizers
  * Samplers

## Installation

You can install the latest released version of Boax from PyPI via:

```sh
pip install boax
```

or you can install the latest development version from GitHub:

```sh
pip install git+https://github.com/Lando-L/boax.git
```

## Getting Started

Here is a quick-start example of the two main components that form the Bayesian optimization loop.
For more details, check out the [docs](https://boax.readthedocs.io/en/latest/) and our [tutorials](https://github.com/Lando-L/boax/blob/main/examples).

1. Create a dataset from a noisy sinusoid.

```python
from functools import partial

from jax import config

# Double precision is highly recommended.
config.update("jax_enable_x64", True)

from jax import jit
from jax import lax
from jax import nn
from jax import numpy as jnp
from jax import random
from jax import scipy
from jax import value_and_grad
from jax import vmap

import optax
import matplotlib.pyplot as plt

from boax.prediction import kernels, means, models
from boax.optimization import acquisitions, maximizers, samplers

bounds = jnp.array([[-3, 3]])

def objective(x):
  return jnp.sin(4 * x[..., 0]) + jnp.cos(2 * x[..., 0])

sample_key, noise_key, maximizer_key = random.split(random.key(0), 3)
x_train = random.uniform(sample_key, minval=bounds[0, 0], maxval=bounds[0, 1], shape=(10, 1))
y_train = objective(x_train) + 0.3 * random.normal(noise_key, shape=(10,))
```

2. Fit a Gaussian Process surrogate model to the training dataset.

```python
def prior(amplitude, length_scale, noise):
  return models.gaussian_process(
    means.zero(),
    kernels.scale(nn.softplus(amplitude), kernels.rbf(nn.softplus(length_scale))),
    nn.softplus(noise),
  )

optimizer = optax.adam(0.01)

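def target_log_prob(params):
  # Despite its name, this returns the *negative* log marginal likelihood,
  # so that minimizing it with Adam maximizes the likelihood of the data.
  mean, cov = prior(**params)(x_train)
  return -scipy.stats.multivariate_normal.logpdf(y_train, mean, cov)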

params = {
  'amplitude': jnp.zeros(()),
  'length_scale': jnp.zeros(()),
  'noise': jnp.array(-5.),
}

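def train_step(state, iteration):
  # state is a (params, opt_state) tuple carried through lax.scan.
  loss, grads = value_and_grad(target_log_prob)(state[0])
  updates, opt_state = optimizer.update(grads, state[1])
  params = optax.apply_updates(state[0], updates)

  # Return the new carry and the per-step loss, which scan stacks into `history`.
  return (params, opt_state), loss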

(next_params, next_opt_state), history = lax.scan(
  jit(train_step),
  (params, optimizer.init(params)),
  jnp.arange(500)
)
```
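
The raw parameters above are unconstrained and are passed through `nn.softplus` before they reach the kernel and the noise term, which keeps the amplitude, length scale, and noise strictly positive during gradient descent. The following dependency-free sketch reproduces that mapping for the initial values used above; the `softplus` helper is a plain-Python stand-in written for illustration, not part of Boax or JAX.

```python
import math


def softplus(x: float) -> float:
    """Numerically stable softplus, log(1 + exp(x))."""
    # max(x, 0) + log1p(exp(-|x|)) avoids overflow for large |x|.
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


# Raw (unconstrained) initial values, copied from the `params` dictionary above.
raw_params = {"amplitude": 0.0, "length_scale": 0.0, "noise": -5.0}

# Positive values that the kernel and noise term would actually receive.
constrained = {name: softplus(value) for name, value in raw_params.items()}

for name, value in constrained.items():
    print(f"{name}: {value:.4f}")
```

With these starting points, the amplitude and length scale begin at roughly 0.69 and the noise at roughly 0.0067, so the optimizer starts from a nearly noise-free model and can move in either direction in the unconstrained space.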

3. Construct and optimize a UCB acquisition function.

```python
surrogate = models.gaussian_process_regression(
  x_train,
  y_train,
  means.zero(),
  kernels.scale(nn.softplus(next_params['amplitude']), kernels.rbf(nn.softplus(next_params['length_scale']))),
  nn.softplus(next_params['noise']),
)

acqf = acquisitions.upper_confidence_bound(2.0, surrogate)
maximizer = maximizers.bfgs(bounds, q=1, num_restarts=25, num_raw_samples=500)

init_candidates = maximizer.init(maximizer_key, acqf)
candidates, values = maximizer.maximize(init_candidates, acqf)
```
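
To make the role of the acquisition function concrete, here is a minimal plain-Python sketch of the standard upper confidence bound rule, which scores a candidate by its posterior mean plus a multiple of its posterior standard deviation. The posterior values below are made-up illustrative numbers, and the convention that the `2.0` argument multiplies the standard deviation directly is an assumption; Boax's own scaling may differ.

```python
import math


def upper_confidence_bound(mean: float, variance: float, beta: float) -> float:
    """Textbook UCB score: posterior mean plus beta times posterior std."""
    return mean + beta * math.sqrt(variance)


# Hypothetical posterior summaries at three candidate points (not real output).
candidates = [-1.0, 0.0, 1.0]
posterior_means = [0.2, 0.5, 0.4]
posterior_variances = [0.01, 0.04, 0.25]

scores = [
    upper_confidence_bound(m, v, beta=2.0)
    for m, v in zip(posterior_means, posterior_variances)
]

# The next point to evaluate is the candidate with the highest score.
best_index = max(range(len(scores)), key=scores.__getitem__)
best_candidate = candidates[best_index]
print(best_candidate, scores[best_index])
```

In this toy example the third candidate wins even though its mean is not the highest, because its larger uncertainty is rewarded; this is the exploration and exploitation trade-off that the maximizer searches over continuously within `bounds`.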

## Citing Boax

To cite Boax, please use the following citation:

```bibtex
@software{boax2023github,
  author = {Lando L{\"o}per},
  title = {{B}oax: A Bayesian Optimization library for {JAX}},
  url = {https://github.com/Lando-L/boax},
  version = {0.0.2},
  year = {2023},
}
```

In the above BibTeX entry, the version number is intended to be that from [boax/version.py](https://github.com/Lando-L/boax/blob/main/boax/version.py), and the year corresponds to the project's open-source release.
