Metadata-Version: 2.1
Name: kernax
Version: 0.1
Summary: Regularized Stein thinning using JAX
Author: Brian Staber
Author-email: Brian Staber <brian.staber@safrangroup.com>
License: MIT License
Project-URL: homepage, https://gitlab.com/drti/kernax
Project-URL: documentation, https://kernax.readthedocs.io/en/latest/
Project-URL: repository, https://gitlab.com/drti/kernax
Keywords: machine learning,statistics,mcmc,thinning,Stein
Platform: Linux
Platform: Mac OS-X
Platform: Unix
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: MIT License
Classifier: Operating System :: MacOS
Classifier: Operating System :: POSIX
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Topic :: Scientific/Engineering
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Classifier: Topic :: Scientific/Engineering :: Mathematics
Requires-Python: >=3.9
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: jax
Requires-Dist: blackjax==0.9.6
Requires-Dist: numpy
Requires-Dist: scipy
Requires-Dist: typing-extensions>=4.4.0

<h1 align="center">Kernax: regularized Stein thinning</h1>

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal

from kernax import RegularizedSteinThinning, SteinThinning, laplace_log_p_softplus
from kernax.utils import median_heuristic

# Draw 1000 samples from a 2D standard Gaussian
rng_key = jax.random.PRNGKey(0)
x = jax.random.normal(rng_key, (1000, 2))

# Target log-density and its score function (gradient of the log-density)
def logprob_fn(x):
    return multivariate_normal.logpdf(x, mean=jnp.zeros(2), cov=jnp.eye(2))

score_fn = jax.grad(logprob_fn)

# Log-density and score evaluated at every sample
log_p = jax.vmap(logprob_fn, 0)(x)
score_values = jax.vmap(score_fn, 0)(x)

# Kernel lengthscale chosen with the median heuristic
lengthscale = jnp.array([median_heuristic(x)])

# Standard Stein thinning: indices of 100 selected points
stein_fn = SteinThinning(x, score_values, lengthscale)
indices = stein_fn(100)

# Regularized Stein thinning also needs the log-density values and the Laplacian correction term
laplace_log_p_values = laplace_log_p_softplus(x, score_fn)
reg_stein_fn = RegularizedSteinThinning(x, log_p, score_values, laplace_log_p_values, lengthscale)
indices = reg_stein_fn(100)
```
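To give an idea of what `SteinThinning` computes, here is a minimal, self-contained sketch of greedy kernel Stein discrepancy (KSD) thinning written directly in JAX. It is not kernax's implementation and rests on several assumptions: an inverse multiquadric (IMQ) base kernel `k(x, y) = (1 + |x - y|^2 / l^2)^(-1/2)`, the Langevin Stein kernel `k_p(x, y) = ∇_x·∇_y k + ∇_x k·s(y) + ∇_y k·s(x) + k s(x)·s(y)` with score `s = ∇ log p`, and the usual greedy rule that, at step t, picks the index i minimizing `k_p(x_i, x_i)/2 + Σ_{j<t} k_p(x_{π_j}, x_i)` (a point may be selected more than once). The names `median_heuristic_sketch`, `imq_kernel`, `stein_kernel`, `stein_gram` and `greedy_stein_thinning` are hypothetical, and the median heuristic shown (median pairwise Euclidean distance) is one common convention that may differ from `kernax.utils.median_heuristic`.

```python
import jax
import jax.numpy as jnp


def median_heuristic_sketch(x):
    """Median of pairwise Euclidean distances (one common convention, assumed)."""
    sq_dists = jnp.sum((x[:, None, :] - x[None, :, :]) ** 2, axis=-1)
    iu = jnp.triu_indices(x.shape[0], k=1)  # distinct pairs only
    return jnp.sqrt(jnp.median(sq_dists[iu]))


def imq_kernel(x, y, lengthscale, beta=-0.5):
    """Assumed IMQ base kernel k(x, y) = (1 + |x - y|^2 / l^2)^beta."""
    return (1.0 + jnp.sum((x - y) ** 2) / lengthscale**2) ** beta


def stein_kernel(x, y, sx, sy, lengthscale):
    """Langevin Stein kernel k_p(x, y) built from the IMQ base kernel."""
    k = lambda a, b: imq_kernel(a, b, lengthscale)
    grad_x = jax.grad(k, argnums=0)(x, y)
    grad_y = jax.grad(k, argnums=1)(x, y)
    # Trace of the mixed Hessian: sum_i d^2 k / (dx_i dy_i)
    div_xy = jnp.trace(jax.jacfwd(jax.grad(k, argnums=1), argnums=0)(x, y))
    return div_xy + grad_x @ sy + grad_y @ sx + k(x, y) * (sx @ sy)


def stein_gram(x, scores, lengthscale):
    """n x n matrix K[i, j] = k_p(x_i, x_j)."""
    row = lambda xi, si: jax.vmap(
        lambda xj, sj: stein_kernel(xi, xj, si, sj, lengthscale)
    )(x, scores)
    return jax.vmap(row)(x, scores)


def greedy_stein_thinning(K, m):
    """Select m indices (with replacement) by greedy KSD minimization."""
    diag = jnp.diag(K)

    def step(acc, _):
        # acc[i] = sum over already-selected j of K[j, i]
        i = jnp.argmin(0.5 * diag + acc)
        return acc + K[i], i

    init = jnp.zeros(K.shape[0], dtype=K.dtype)
    _, idx = jax.lax.scan(step, init, None, length=m)
    return idx


# Usage on a 2D standard Gaussian, whose score is simply -x
x = jax.random.normal(jax.random.PRNGKey(0), (300, 2))
scores = -x
lengthscale = median_heuristic_sketch(x)
K = stein_gram(x, scores, lengthscale)
indices = greedy_stein_thinning(K, 30)
```

Building the full n × n Stein kernel matrix costs O(n²) kernel evaluations and memory, which is manageable for a few thousand samples; the greedy loop then only adds one row of `K` to the running sum per step.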

## Documentation

Documentation is available at [readthedocs](https://kernax.readthedocs.io/en/latest/?kernax=latest).

## Contributing

This code is not meant to be an evolving library. However, feel free to create issues and merge requests.

## Setup

All the requirements are listed in the file `env.yml`, which can be used to create a conda environment as follows:

```console
cd kernax-main
conda env create -n kernax -f env.yml
```
Activate the new environment:
```console
conda activate kernax
```
Then add the package to your `PYTHONPATH`, or simply run:
```console
pip install -e .
```
Finally, check that the installation works:
```console
python -c "import kernax; print(dir(kernax))"
```

## Reproducibility

This code implements the regularized Stein thinning algorithm introduced in the paper [Kernel Stein Discrepancy thinning: a theoretical perspective of pathologies and a practical fix with regularization](https://arxiv.org/pdf/2301.13528.pdf).
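Regularized Stein thinning modifies the greedy KSD objective with a Laplacian correction and an entropic term. The sketch below is only an illustration under a stated assumption: it assumes the per-step objective takes the form `k_p(x_i, x_i) + (Δ log p(x_i))⁺ + 2 Σ_{j<t} k_p(x_{π_j}, x_i) - λ t log p(x_i)`, which is the standard greedy Stein thinning objective (scaled by 2) plus the two extra terms. The exact weighting used in the paper and in `RegularizedSteinThinning` may differ, so treat this as a hedged illustration rather than a reference implementation. The positive part is replaced by `jax.nn.softplus`, mirroring the name `laplace_log_p_softplus` (an assumption about what that helper computes). The IMQ base kernel with a fixed lengthscale of 1, and the names `lam`, `laplacian_log_p` and `regularized_greedy`, are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal


def imq_kernel(x, y, lengthscale, beta=-0.5):
    """Assumed IMQ base kernel k(x, y) = (1 + |x - y|^2 / l^2)^beta."""
    return (1.0 + jnp.sum((x - y) ** 2) / lengthscale**2) ** beta


def stein_kernel(x, y, sx, sy, lengthscale):
    """Langevin Stein kernel built from the IMQ base kernel."""
    k = lambda a, b: imq_kernel(a, b, lengthscale)
    # Trace of the mixed Hessian: sum_i d^2 k / (dx_i dy_i)
    div_xy = jnp.trace(jax.jacfwd(jax.grad(k, argnums=1), argnums=0)(x, y))
    grad_x = jax.grad(k, argnums=0)(x, y)
    grad_y = jax.grad(k, argnums=1)(x, y)
    return div_xy + grad_x @ sy + grad_y @ sx + k(x, y) * (sx @ sy)


def laplacian_log_p(score_fn, x):
    """Delta log p at each sample: trace of the Jacobian of the score."""
    return jax.vmap(lambda xi: jnp.trace(jax.jacfwd(score_fn)(xi)))(x)


def regularized_greedy(K, log_p, lap_plus, lam, m):
    """Greedy selection under the ASSUMED regularized objective."""
    diag = jnp.diag(K)

    def step(carry, _):
        acc, t = carry  # acc[i] = sum of K[j, i] over selected j; t = current step
        # Stein kernel terms + Laplacian correction - entropic term (assumed form)
        obj = diag + lap_plus + 2.0 * acc - lam * t * log_p
        i = jnp.argmin(obj)
        return (acc + K[i], t + 1.0), i

    init = (jnp.zeros_like(diag), jnp.asarray(1.0, dtype=K.dtype))
    _, idx = jax.lax.scan(step, init, None, length=m)
    return idx


# Usage on a 2D standard Gaussian target
def logprob_fn(z):
    return multivariate_normal.logpdf(z, mean=jnp.zeros(2), cov=jnp.eye(2))


score_fn = jax.grad(logprob_fn)
x = jax.random.normal(jax.random.PRNGKey(0), (300, 2))
log_p = jax.vmap(logprob_fn)(x)
scores = jax.vmap(score_fn)(x)
lap_plus = jax.nn.softplus(laplacian_log_p(score_fn, x))  # smooth positive part

lengthscale = 1.0  # fixed for illustration
K = jax.vmap(
    lambda xi, si: jax.vmap(lambda xj, sj: stein_kernel(xi, xj, si, sj, lengthscale))(x, scores)
)(x, scores)
indices = regularized_greedy(K, log_p, lap_plus, lam=1.0, m=30)
```

For this Gaussian target the Laplacian of the log-density is constant (equal to -2), so the correction shifts every candidate's objective by the same amount; it only changes the selection for targets whose log-density curves upward somewhere, such as Gaussian mixtures between their modes.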

Please consider citing the paper when using this library:
```bibtex
@article{benard2023kernel,
  title={Kernel Stein Discrepancy thinning: a theoretical perspective of pathologies and a practical fix with regularization},
  author={B{\'e}nard, Cl{\'e}ment and Staber, Brian and Da Veiga, S{\'e}bastien},
  journal={arXiv preprint arXiv:2301.13528},
  year={2023}
}
```

All the numerical experiments presented in the [paper](https://arxiv.org/pdf/2301.13528.pdf) can be reproduced with the scripts available in the `example` folder.

In particular:

* Figures 1, 2 & 3 can be reproduced with the script `example/mog_randn.py`.

* Each experiment in Section 4 and Appendix 1 can be reproduced with the scripts gathered in the following folders:
    * Gaussian mixture: `example/mog4_mcmc` and `example/mog4_mcmc_dim`
    * Mixture of banana-shaped distributions: `example/mobt2_mcmc` and `example/mobt2_mcmc_dim`
    * Bayesian logistic regression: `example/logistic_regression.py`

* Two additional scripts are also available to reproduce figures shown in the supplementary material:
    * Figure 2: `example/mog_weight_weights.py`
    * Figure 6: `example/mog4_mcmc_lambda`
