Metadata-Version: 2.1
Name: dilax
Version: 0.1.2
Summary: Differentiable (binned) likelihoods in JAX.
Project-URL: Homepage, https://github.com/pfackeldey/dilax
Project-URL: Bug Tracker, https://github.com/pfackeldey/dilax/issues
Project-URL: Discussions, https://github.com/pfackeldey/dilax/discussions
Project-URL: Changelog, https://github.com/pfackeldey/dilax/releases
Author-email: Peter Fackeldey <peter.fackeldey@rwth-aachen.de>
License-File: LICENSE
Classifier: Development Status :: 1 - Planning
Classifier: Intended Audience :: Developers
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: BSD License
Classifier: Operating System :: OS Independent
Classifier: Programming Language :: Python
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3 :: Only
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Topic :: Scientific/Engineering
Classifier: Typing :: Typed
Requires-Python: >=3.9
Requires-Dist: equinox>=0.10.6
Requires-Dist: jaxopt>=0.6
Provides-Extra: dev
Requires-Dist: pytest-cov>=3; extra == 'dev'
Requires-Dist: pytest>=6; extra == 'dev'
Provides-Extra: docs
Requires-Dist: autodocsumm~=0.2; extra == 'docs'
Requires-Dist: myst-parser~=0.18; extra == 'docs'
Requires-Dist: sphinx-book-theme~=0.3; extra == 'docs'
Requires-Dist: sphinx-copybutton; extra == 'docs'
Requires-Dist: sphinx-lfs-content~=1.1; extra == 'docs'
Requires-Dist: sphinx~=4.3; extra == 'docs'
Provides-Extra: test
Requires-Dist: pytest-cov>=3; extra == 'test'
Requires-Dist: pytest>=6; extra == 'test'
Description-Content-Type: text/markdown

# dilax

[![Documentation Status](https://readthedocs.org/projects/dilax/badge/?version=latest)](https://dilax.readthedocs.io/en/latest/?badge=latest)
[![Actions Status][actions-badge]][actions-link]
[![PyPI version][pypi-version]][pypi-link]
[![PyPI platforms][pypi-platforms]][pypi-link]

Differentiable (binned) likelihoods in JAX.

## Installation

```bash
python -m pip install dilax
```

From source:

```bash
git clone https://github.com/pfackeldey/dilax
cd dilax
python -m pip install .
```

## Usage - Model definition and fitting

More examples can be found in `examples/`.

_dilax_ in a nutshell:

```python3
import jax
import jax.numpy as jnp
import equinox as eqx

from dilax.likelihood import NLL
from dilax.model import Model, Result
from dilax.optimizer import JaxOptimizer
from dilax.parameter import Parameter, gauss, modifier, unconstrained
from dilax.util import HistDB


jax.config.update("jax_enable_x64", True)


# define a simple model with two processes and two parameters
class MyModel(Model):
    def __call__(self, processes: HistDB, parameters: dict[str, Parameter]) -> Result:
        res = Result()

        # signal
        mu_mod = modifier(name="mu", parameter=parameters["mu"], effect=unconstrained())
        res.add(process="signal", expectation=mu_mod(processes["signal"]))

        # background
        bkg_mod = modifier(name="sigma", parameter=parameters["sigma"], effect=gauss(0.2))
        res.add(process="background", expectation=bkg_mod(processes["background"]))
        return res


# setup model
processes = HistDB({"signal": jnp.array([10.0]), "background": jnp.array([50.0])})
parameters = {
    "mu": Parameter(value=jnp.array([1.0]), bounds=(0.0, jnp.inf)),
    "sigma": Parameter(value=jnp.array([0.0])),
}
model = MyModel(processes=processes, parameters=parameters)

# define negative log-likelihood with data (observation)
nll = NLL(model=model, observation=jnp.array([64.0]))
# jit it!
fast_nll = eqx.filter_jit(nll)

# setup fit: initial values of parameters and a suitable optimizer
init_values = model.parameter_values
optimizer = JaxOptimizer.make(name="ScipyMinimize", settings={"method": "trust-constr"})

# fit
values, state = optimizer.fit(fun=fast_nll, init_values=init_values)

print(values)
# -> {'mu': Array([1.4], dtype=float64),
#     'sigma': Array([4.04723836e-14], dtype=float64)}

# eval model with fitted values/parameters
print(model.update(values=values).evaluate().expectation())
# -> Array([64.], dtype=float64)


# gradients of "prefit" model:
fast_grad_nll_prefit = eqx.filter_grad(fast_nll)
print(fast_grad_nll_prefit({"sigma": jnp.array([0.2])}))
# -> {'sigma': Array([-0.12258065], dtype=float64)}

# gradients of "postfit" model:
postfit_nll = NLL(model=model.update(values=values), observation=jnp.array([64.0]))
fast_grad_nll_postfit = eqx.filter_grad(eqx.filter_jit(postfit_nll))
print(fast_grad_nll_postfit({"sigma": jnp.array([0.2])}))
# -> {'sigma': Array([0.5030303], dtype=float64)}
```
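The printed gradients can be checked by hand. Below is a minimal pure-Python sketch, not dilax's implementation. It assumes three things: `gauss(0.2)` scales the background by `(1 + 0.2 * sigma)`; the Poisson term is `lam - n * ln(lam)`, with the constant `ln(n!)` dropped; and the constraint on `sigma` adds `0.5 * sigma**2`. The helper names (`expectation`, `nll`, `grad_mu`, `grad_sigma`) are hypothetical. Under these assumptions, the analytic gradients reproduce the values printed above, and the fitted point (`mu = 1.4`, `sigma = 0`) is a stationary point.

```python3
import math

# Numbers from the example above.
SIGNAL, BACKGROUND, OBSERVED = 10.0, 50.0, 64.0
WIDTH = 0.2  # width passed to gauss(0.2)


def expectation(mu: float, sigma: float) -> float:
    # Assumption: mu scales the signal, sigma scales the background by (1 + WIDTH * sigma).
    return SIGNAL * mu + BACKGROUND * (1.0 + WIDTH * sigma)


def nll(mu: float, sigma: float) -> float:
    # Poisson term without the constant ln(n!), plus a unit-Gaussian penalty on sigma.
    lam = expectation(mu, sigma)
    return lam - OBSERVED * math.log(lam) + 0.5 * sigma**2


def grad_mu(mu: float, sigma: float) -> float:
    # Chain rule: d(lam - n ln lam)/dmu = (1 - n / lam) * dlam/dmu
    return (1.0 - OBSERVED / expectation(mu, sigma)) * SIGNAL


def grad_sigma(mu: float, sigma: float) -> float:
    # Same chain rule through the background, plus d(0.5 * sigma**2)/dsigma = sigma.
    return (1.0 - OBSERVED / expectation(mu, sigma)) * BACKGROUND * WIDTH + sigma


print(round(grad_sigma(1.0, 0.2), 8))  # prefit (mu=1.0): -0.12258065
print(round(grad_sigma(1.4, 0.2), 7))  # postfit (mu=1.4): 0.5030303
print(grad_mu(1.4, 0.0), grad_sigma(1.4, 0.0))  # both ~0 at the fitted point
```

Deriving gradients by hand like this is only practical for a toy model. With dilax, JAX computes them automatically through `eqx.filter_grad`.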

## Contributing

See [CONTRIBUTING.md](CONTRIBUTING.md) for instructions on how to contribute.

## License

Distributed under the terms of the [BSD license](LICENSE).

<!-- prettier-ignore-start -->
[actions-badge]:            https://github.com/pfackeldey/dilax/workflows/CI/badge.svg
[actions-link]:             https://github.com/pfackeldey/dilax/actions
[pypi-link]:                https://pypi.org/project/dilax/
[pypi-platforms]:           https://img.shields.io/pypi/pyversions/dilax
[pypi-version]:             https://img.shields.io/pypi/v/dilax
<!-- prettier-ignore-end -->
