Metadata-Version: 2.1
Name: probjax
Version: 0.1.0
Summary: Jax library for probabilistic computations
Author-email: Manuel Gloeckler <manuel.gloeckler@uni-tuebingen.de>
License: MIT
Keywords: probabilistic,jax,computation
Classifier: License :: OSI Approved :: MIT License
Classifier: Intended Audience :: Developers
Classifier: Intended Audience :: Education
Classifier: Intended Audience :: Science/Research
Classifier: Operating System :: POSIX :: Linux
Classifier: Operating System :: MacOS :: MacOS X
Classifier: Programming Language :: Python
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.7
Classifier: Programming Language :: Python :: 3.8
Classifier: Programming Language :: Python :: 3.9
Requires-Python: >=3.9
Description-Content-Type: text/markdown
License-File: LICENSE.txt
Requires-Dist: numpy>=2.1.2
Requires-Dist: matplotlib
Requires-Dist: jax>=0.4.34
Requires-Dist: jaxlib>=0.4.34
Requires-Dist: optax
Requires-Dist: ott-jax
Requires-Dist: networkx
Requires-Dist: blackjax
Requires-Dist: sympy
Provides-Extra: dev
Requires-Dist: pytest; extra == "dev"
Requires-Dist: ruff; extra == "dev"
Requires-Dist: pytest-xdist; extra == "dev"

# Probjax

Probabilistic computation in JAX. This library is under active development and is not yet ready for use. It aims to provide a simple and flexible way to build probabilistic models and perform inference in them. It provides the following tools:
- **Core**: A set of core function transformations and primitives useful for building probabilistic models.
    - **Tracing**: Tracing and manipulation of function traces. (Very incomplete)
    - **Automatic inversion**: Automatic inversion of functions. (Rather complete, with some limitations)
    - **Automatic log_prob**: Automatic computation of log-probabilities (rather incomplete), and of log-probabilities of transformed distributions through automatic inversion and the log-determinant (logdet) of the Jacobian (rather complete).
- **Distributions**: A set of distributions with support for sampling, log-probability and more.
- **Inference**: Some inference algorithms. (incomplete)
- **Neural networks**: Some neural network layers and models, based on [Haiku](https://github.com/deepmind/dm-haiku). These include classical layers such as Transformers, ResNets, and U-Nets, as well as specialised layers for normalising flows, such as coupling layers, autoregressive layers, etc. (complete)
- **Utilities**: Some utilities for numerical computation, e.g. `odeint`, `sdeint`, etc. (complete)
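
The log-probability of a transformed distribution mentioned under **Automatic log_prob** rests on the change-of-variables formula. The following is a minimal sketch in plain JAX, not the Probjax API: the functions `forward`, `inverse`, and `transformed_log_prob` are hypothetical names chosen for illustration, and the inverse is written by hand here, whereas the library aims to derive it automatically.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def forward(z):
    # Hypothetical bijection for illustration: x = exp(z) with z ~ N(0, 1),
    # so x follows a log-normal distribution.
    return jnp.exp(z)


def inverse(x):
    # Hand-written inverse of `forward` (an assumption of this sketch).
    return jnp.log(x)


def transformed_log_prob(x):
    # Change of variables for a scalar bijection:
    # log p_X(x) = log p_Z(f^{-1}(x)) + log |d f^{-1}(x) / dx|
    z = inverse(x)
    log_det = jnp.log(jnp.abs(jax.grad(inverse)(x)))
    return norm.logpdf(z) + log_det


x = 2.0
lp = transformed_log_prob(x)

# Closed-form log-normal log-density, used only as a sanity check.
expected = norm.logpdf(jnp.log(x)) - jnp.log(x)
print(lp, expected)
```

For vector-valued bijections the scalar derivative would be replaced by the log of the absolute Jacobian determinant, for example via `jax.jacobian` and `jnp.linalg.slogdet`.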

## Installation

Probjax can be installed from a local clone of the repository using pip (the `-e` flag performs an editable install):

```bash
pip install -e probjax
```

Additionally, you can install the benchmark scripts using:

```bash
pip install -e probjax/scoresbibm
```
