Metadata-Version: 2.1
Name: dilax
Version: 0.1.6
Summary: Differentiable (binned) likelihoods in JAX.
Project-URL: Homepage, https://github.com/pfackeldey/dilax
Project-URL: Bug Tracker, https://github.com/pfackeldey/dilax/issues
Project-URL: Discussions, https://github.com/pfackeldey/dilax/discussions
Project-URL: Changelog, https://github.com/pfackeldey/dilax/releases
Author-email: Peter Fackeldey <peter.fackeldey@rwth-aachen.de>
License-File: LICENSE
Classifier: Development Status :: 1 - Planning
Classifier: Intended Audience :: Developers
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: BSD License
Classifier: Operating System :: OS Independent
Classifier: Programming Language :: Python
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3 :: Only
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Topic :: Scientific/Engineering
Classifier: Typing :: Typed
Requires-Python: >=3.10
Requires-Dist: equinox>=0.10.6
Requires-Dist: jaxopt>=0.6
Provides-Extra: dev
Requires-Dist: pytest-cov>=3; extra == 'dev'
Requires-Dist: pytest>=6; extra == 'dev'
Provides-Extra: docs
Requires-Dist: autodocsumm~=0.2; extra == 'docs'
Requires-Dist: myst-parser~=0.18; extra == 'docs'
Requires-Dist: sphinx-book-theme~=0.3; extra == 'docs'
Requires-Dist: sphinx-copybutton; extra == 'docs'
Requires-Dist: sphinx-lfs-content~=1.1; extra == 'docs'
Requires-Dist: sphinx~=4.3; extra == 'docs'
Provides-Extra: test
Requires-Dist: pytest-cov>=3; extra == 'test'
Requires-Dist: pytest>=6; extra == 'test'
Description-Content-Type: text/markdown

# dilax

[![Documentation Status](https://readthedocs.org/projects/dilax/badge/?version=latest)](https://dilax.readthedocs.io/en/latest/?badge=latest)
[![Actions Status][actions-badge]][actions-link]
[![PyPI version][pypi-version]][pypi-link]
[![PyPI platforms][pypi-platforms]][pypi-link]

Differentiable (binned) likelihoods in JAX.

## Installation

```bash
python -m pip install dilax
```

From source:

```bash
git clone https://github.com/pfackeldey/dilax
cd dilax
python -m pip install .
```

## Usage - Model definition and fitting

See `examples/` for more.

_dilax_ in a nutshell:

```python3
import equinox as eqx
import jax
import jax.numpy as jnp

import dilax as dlx

jax.config.update("jax_enable_x64", True)


# define a simple model with two processes and two parameters
class MyModel(dlx.Model):
    def __call__(self, processes: dict, parameters: dict) -> dlx.Result:
        res = dlx.Result()

        # signal
        mu_mod = dlx.modifier(
            name="mu", parameter=parameters["mu"], effect=dlx.effect.unconstrained()
        )
        res.add(process="signal", expectation=mu_mod(processes["signal"]))

        # background
        bkg_mod = dlx.modifier(
            name="sigma", parameter=parameters["sigma"], effect=dlx.effect.gauss(0.2)
        )
        res.add(process="background", expectation=bkg_mod(processes["background"]))
        return res


# set up the model
processes = {"signal": jnp.array([10.0]), "background": jnp.array([50.0])}
parameters = {
    "mu": dlx.Parameter(value=jnp.array([1.0]), bounds=(0.0, jnp.inf)),
    "sigma": dlx.Parameter(value=jnp.array([0.0])),
}
model = MyModel(processes=processes, parameters=parameters)

# define negative log-likelihood with data (observation)
nll = dlx.likelihood.NLL(model=model, observation=jnp.array([64.0]))
# jit it!
fast_nll = eqx.filter_jit(nll)

# set up the fit: initial parameter values and a suitable optimizer
init_values = model.parameter_values
optimizer = dlx.optimizer.JaxOptimizer.make(
    name="ScipyMinimize", settings={"method": "trust-constr"}
)

# fit
values, state = optimizer.fit(fun=fast_nll, init_values=init_values)

print(values)
# -> {'mu': Array([1.4], dtype=float64),
#     'sigma': Array([4.04723836e-14], dtype=float64)}

# eval model with fitted values
print(model.update(values=values).evaluate().expectation())
# -> Array([64.], dtype=float64)


# gradients of "prefit" model:
print(eqx.filter_grad(nll)({"sigma": jnp.array([0.2])}))
# -> {'sigma': Array([-0.12258065], dtype=float64)}


# gradients of "postfit" model:
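@eqx.filter_grad
@eqx.filter_jit
def grad_postfit_nll(where: dict[str, jax.Array]) -> jax.Array:
    # scalar NLL of the model updated with the fitted values;
    # filter_grad turns it into a dict of gradients w.r.t. `where`
    postfit_nll = dlx.likelihood.NLL(
        model=model.update(values=values), observation=jnp.array([64.0])
    )
    return postfit_nll(where)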


print(grad_postfit_nll({"sigma": jnp.array([0.2])}))
# -> {'sigma': Array([0.5030303], dtype=float64)}
```
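The printed numbers can be cross-checked by hand. The sketch below is a plain-JAX reimplementation, not dilax code, and it rests on two assumptions: `dlx.effect.gauss(0.2)` scales the background by `(1 + 0.2 * sigma)`, and the NLL is the Poisson term plus a standard-normal penalty `0.5 * sigma**2`. The helper names (`expectation`, `nll`, `SIGNAL`, `WIDTH`, ...) are hypothetical and exist only for this illustration. Under these assumptions it reproduces the prefit and postfit gradients at `sigma = 0.2`, as well as the best-fit `mu`.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Toy inputs copied from the example above (hypothetical names).
SIGNAL = jnp.array([10.0])
BACKGROUND = jnp.array([50.0])
OBSERVATION = jnp.array([64.0])
WIDTH = 0.2  # width passed to dlx.effect.gauss(0.2)


def expectation(mu: jax.Array, sigma: jax.Array) -> jax.Array:
    # Assumption: mu scales the signal directly, and the Gaussian modifier
    # scales the background linearly by (1 + WIDTH * sigma).
    return mu * SIGNAL + (1.0 + WIDTH * sigma) * BACKGROUND


def nll(mu: jax.Array, sigma: jax.Array) -> jax.Array:
    # Assumption: Poisson NLL without the constant log(n!) term, plus a
    # standard-normal constraint penalty for sigma.
    lam = expectation(mu, sigma)
    poisson = jnp.sum(lam - OBSERVATION * jnp.log(lam))
    penalty = 0.5 * jnp.sum(sigma**2)
    return poisson + penalty


# d(NLL)/d(sigma) at sigma = 0.2, using the prefit (mu = 1) and postfit (mu = 1.4) values
grad_sigma = jax.grad(nll, argnums=1)
prefit = grad_sigma(jnp.array([1.0]), jnp.array([0.2]))
postfit = grad_sigma(jnp.array([1.4]), jnp.array([0.2]))
print(prefit, postfit)  # close to -0.12258065 and 0.5030303, as printed above

# With sigma = 0 (zero penalty), the best fit makes the expectation equal the data.
mu_hat = (OBSERVATION - BACKGROUND) / SIGNAL
print(mu_hat)  # 1.4, matching the fitted 'mu' above
```

The sign flip between the prefit and postfit gradients comes from the Poisson term: with `mu = 1` the expectation (62) undershoots the observation (64), while with `mu = 1.4` it overshoots it (66). The constraint penalty adds `+0.2` in both cases.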

## Contributing

See [CONTRIBUTING.md](CONTRIBUTING.md) for instructions on how to contribute.

## License

Distributed under the terms of the [BSD license](LICENSE).

<!-- prettier-ignore-start -->
[actions-badge]:            https://github.com/pfackeldey/dilax/workflows/CI/badge.svg
[actions-link]:             https://github.com/pfackeldey/dilax/actions
[pypi-link]:                https://pypi.org/project/dilax/
[pypi-platforms]:           https://img.shields.io/pypi/pyversions/dilax
[pypi-version]:             https://img.shields.io/pypi/v/dilax
<!-- prettier-ignore-end -->
