Metadata-Version: 2.4
Name: jmstate
Version: 0.17.2
Summary: Joint modeling with automatic differentiation
Author: Félix Laplante
Project-URL: Source, https://github.com/felixlaplante/jmstate
Classifier: Programming Language :: Python :: 3
Classifier: Operating System :: POSIX :: Linux
Classifier: Operating System :: MacOS
Classifier: Operating System :: Microsoft :: Windows
Requires-Python: >=3.10
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: numpy
Requires-Dist: torch
Requires-Dist: matplotlib
Requires-Dist: scikit-learn
Requires-Dist: rich
Requires-Dist: tqdm
Dynamic: license-file

# jmstate

**jmstate** is a Python package for **nonlinear multi-state joint modeling** of longitudinal and time-to-event data. Built on [PyTorch](https://pytorch.org/), it enables flexible specification of regression and link functions — including neural networks — while still offering built-in parametric baseline hazards and utilities for inference and prediction.

The package implements the methodology from:

> **A General Framework for Joint Multi-State Models**
> Félix Laplante & Christophe Ambroise (2025) — [arXiv:2510.07128](https://arxiv.org/abs/2510.07128)

---

## Installation

```bash
pip install jmstate
```

**Requirements:** Python ≥ 3.10, PyTorch, scikit-learn, NumPy, Matplotlib, rich, tqdm.

---

## Documentation

Full API reference and tutorials: [jmstate documentation](https://felixlaplante0.github.io/jmstate/)

---

## The Model

`jmstate` fits a **joint model** that links a longitudinal biomarker process to a multi-state event history through shared individual random effects.

### Longitudinal sub-model

Individual observations follow

$$y_{ij} = h(t_{ij}, \psi_i) + \epsilon_{ij}, \qquad \epsilon_{ij} \sim \mathcal{N}(0, R)$$

where $h$ is a user-defined regression function (e.g. bi-exponential, logistic) and the individual parameters $\psi_i$ are defined via

$$\psi_i = f(\gamma, X_i, b_i), \qquad b_i \sim \mathcal{N}(0, Q)$$

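with $\gamma$ the fixed population-level effects, $X_i$ the covariates, $b_i$ the individual random effects, and $f$ a user-defined function; $R$ and $Q$ are the residual and random-effects covariance matrices, respectively.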
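To make the two equations concrete, here is a minimal simulation sketch in plain PyTorch (no jmstate calls). It assumes a bi-exponential $h$, the log-normal form $\psi_i = \gamma \odot \exp(b_i)$ for $f$, no covariates, and made-up values for $\gamma$, $Q$ and $R$.

```python
import torch

torch.manual_seed(0)

n, m, q = 5, 8, 3  # individuals, measurement times, random effects

# Hypothetical population values: gamma for (A, k, ka), Q for b_i, R for the noise
gamma = torch.tensor([2.0, 0.3, 1.5])
Q = 0.1 * torch.eye(q)
R = torch.tensor([[0.05]])


def f(gamma: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # psi_i = gamma * exp(b_i): log-normal individual parameters with median gamma
    return gamma * torch.exp(b)


def h(t: torch.Tensor, psi: torch.Tensor) -> torch.Tensor:
    # Bi-exponential regression function evaluated at every time for every individual
    A, k, ka = psi.chunk(3, dim=-1)  # each (n, 1)
    return (A * (torch.exp(-k * t) - torch.exp(-ka * t))).unsqueeze(-1)  # (n, m, 1)


t = torch.linspace(0.5, 10.0, m)  # shared measurement grid, shape (m,)
b = torch.distributions.MultivariateNormal(torch.zeros(q), Q).sample((n,))  # (n, q)
psi = f(gamma, b)  # (n, q)
eps = torch.distributions.MultivariateNormal(torch.zeros(1), R).sample((n, m))  # (n, m, 1)
y = h(t, psi) + eps  # y_ij = h(t_ij, psi_i) + eps_ij
```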

### Multi-state sub-model

Let $G = (V, E)$ be a directed graph, where $V$ denotes the set of states and $E \subseteq V \times V$ the set of admissible transitions. The graph encodes all possible paths of the multi-state process, accommodating competing and recurrent transitions as well as absorbing states. The hazard of individual $i$ for a transition $k \to k'$ at time $t$, given entry into state $k$ at time $t_0$, is

$$\lambda^{k \to k'}(t_0, t) = \lambda_0^{k \to k'}(t_0, t) \exp\left( \alpha^{k \to k'} g^{k \to k'}(t, \psi_i) + \beta^{k \to k'} X_i \right),$$

where $\lambda_0^{k \to k'}$ are parametric baseline hazards, $g^{k \to k'}$ are link functions acting as a bridge between the longitudinal and the semi-Markov multi-state processes, and $\alpha^{k \to k'}$, $\beta^{k \to k'}$ are transition-specific coefficients.

The model supports **arbitrary state graphs** (recurrent, absorbing, monotone, etc.) under a semi-Markov assumption.
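The hazard formula and the state graph can be sketched as follows. Everything here is hypothetical and independent of jmstate: an illness-death graph, constant (exponential) baselines, a toy link $g$, and the probability of remaining in the current state, obtained by numerically integrating the sum of the outgoing hazards.

```python
import torch

# Hypothetical illness-death graph: 1 = healthy, 2 = ill, 3 = dead (absorbing)
E = {(1, 2), (1, 3), (2, 3)}


def is_admissible(trajectory: list[tuple[float, int]]) -> bool:
    # Each consecutive pair of visited states must be an edge (k, k') of the graph
    states = [s for _, s in trajectory]
    return all((a, b) in E for a, b in zip(states, states[1:]))


def exposure(t: torch.Tensor) -> torch.Tensor:
    # Hypothetical link g(t, psi_i): a saturating cumulative exposure
    return 1 - torch.exp(-t)


def hazard(t, t0, rate, alpha, beta, g, x):
    # lambda(t0, t) = lambda_0(t0, t) * exp(alpha * g(t) + beta . x);
    # an exponential baseline is constant, so it ignores the entry time t0
    baseline = torch.full_like(t, rate)
    return baseline * torch.exp(alpha * g(t) + beta @ x)


def survival(t0, t1, transitions, x, n_grid=2001):
    # P(no transition out of the current state on [t0, t1])
    # = exp(-sum over k' of the integrated hazards), via the trapezoidal rule
    u = torch.linspace(t0, t1, n_grid)
    total = sum(hazard(u, t0, *tr, x=x) for tr in transitions)
    return torch.exp(-torch.trapezoid(total, u))


x = torch.tensor([0.5])  # one covariate
# Outgoing transitions of state 1 as (rate, alpha, beta, g): 1 -> 2 and 1 -> 3
out_of_1 = [
    (0.1, 0.8, torch.tensor([0.2]), exposure),
    (0.05, 0.0, torch.tensor([0.4]), exposure),
]
print(is_admissible([(0.0, 1), (2.0, 2), (5.0, 3)]))  # True
print(survival(0.0, 3.0, out_of_1, x).item())
```

With $\alpha = 0$ the integrand is constant, so the result reduces to the closed form $\exp\left(-(t_1 - t_0) \sum_{k'} \lambda_0^{k \to k'} e^{\beta^{k \to k'} X_i}\right)$, which is a convenient sanity check.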

### Estimation

Parameters are estimated by maximizing the observed-data log-likelihood, whose gradient is given by the **Fisher identity**

$$\nabla_\theta \log \mathcal{L}(\theta; x) = \mathbb{E}_{b \sim p(\cdot \mid x, \theta)}\left[ \nabla_\theta \log \mathcal{L}(\theta; x, b) \right],$$

where $x$ denotes the observed data and $\mathcal{L}(\theta; x, b)$ is the complete-data likelihood, i.e. the joint density of the data and the random effects under $\theta$.

In practice, the expectation is approximated by averaging over random effects drawn by a **Metropolis-within-Gibbs MCMC** sampler, and the resulting gradient estimate drives a stochastic gradient ascent step. Convergence is monitored via an $R^2$-based stationarity test.
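The Fisher identity can be checked on a toy model where everything is Gaussian and the exact score is known. The sketch below is an illustration under stated assumptions, not jmstate's implementation: it uses random-walk Metropolis updates within a Gibbs sweep over the $b_i$, averages $\nabla_\theta \log \mathcal{L}(\theta; x, b)$ (computed with autograd) over the draws, and takes plain gradient ascent steps on a scalar $\theta$.

```python
import torch

torch.manual_seed(0)

# Toy model (assumption, unrelated to jmstate): b_i ~ N(theta, 1), x_i | b_i ~ N(b_i, 1),
# so marginally x_i ~ N(theta, 2) and the exact score is sum_i (x_i - theta) / 2
x = torch.randn(200) + 1.5


def complete_loglik(theta: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # log L(theta; x, b) up to constants, summed over individuals (last dim)
    return -0.5 * ((x - b) ** 2).sum(-1) - 0.5 * ((b - theta) ** 2).sum(-1)


def conditional_logp(b: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
    # Unnormalized log p(b_i | x_i, theta), elementwise
    return -0.5 * (x - b) ** 2 - 0.5 * (b - theta) ** 2


def mwg_sweep(theta: torch.Tensor, b: torch.Tensor, step: float = 1.0) -> torch.Tensor:
    # One Gibbs sweep with a random-walk Metropolis step per b_i; the full conditionals
    # are independent across i here, so all coordinates can be updated in one pass
    prop = b + step * torch.randn_like(b)
    log_ratio = conditional_logp(prop, theta) - conditional_logp(b, theta)
    accept = torch.rand_like(b).log() < log_ratio
    return torch.where(accept, prop, b)


def fisher_gradient(theta: torch.Tensor, b: torch.Tensor, n_sweeps: int = 5):
    # Monte Carlo estimate of E_{b ~ p(. | x, theta)}[grad_theta log L(theta; x, b)]
    grads = []
    for _ in range(n_sweeps):
        b = mwg_sweep(theta, b)
        th = theta.clone().requires_grad_(True)
        (g,) = torch.autograd.grad(complete_loglik(th, b).mean(), th)  # mean over chains
        grads.append(g)
    return torch.stack(grads).mean(), b


theta = torch.tensor(0.0)
b = torch.zeros(16, x.numel())  # 16 parallel chains over all individuals
for _ in range(300):
    grad, b = fisher_gradient(theta, b)
    theta = theta + 0.5 / x.numel() * grad  # stochastic gradient ascent step

print(theta.item(), x.mean().item())
```

Because $x_i \sim \mathcal{N}(\theta, 2)$ marginally, the maximum-likelihood estimate is the sample mean of $x$, so `theta` should settle near `x.mean()`.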

---

## Quick Start

### Step 1 — Define the model design

```python
import torch
from jmstate.types import ModelDesign


# Individual parameters psi_i = f(gamma, X_i, b_i): log-normal around gamma
# (the covariates x are unused in this example)
def indiv_effects_fn(
    fixed: torch.Tensor, x: torch.Tensor, b: torch.Tensor
) -> torch.Tensor:
    return fixed * torch.exp(b)  # (..., n, q)


# Regression function h: bi-exponential biomarker
def pk_fn(t: torch.Tensor, indiv_params: torch.Tensor) -> torch.Tensor:
    A, k, ka = indiv_params.chunk(3, dim=-1)  # each (..., n, 1)
    conc = A * (torch.exp(-k * t) - torch.exp(-ka * t))
    return conc.unsqueeze(-1)  # trailing biomarker dimension d = 1


# Link function g: cumulative biomarker, i.e. the integral of pk_fn over [0, t]
def pk_integral_fn(t: torch.Tensor, indiv_params: torch.Tensor) -> torch.Tensor:
    A, k, ka = indiv_params.chunk(3, dim=-1)
    # A scales both exponential terms
    integral = A * ((1 - torch.exp(-k * t)) / k - (1 - torch.exp(-ka * t)) / ka)
    return integral.unsqueeze(-1)


# Define the model design
design = ModelDesign(
    indiv_effects_fn,
    regression_fn=pk_fn,
    link_fns={
        (1, 1): pk_integral_fn,
        (1, 2): pk_integral_fn,
    },
)
```
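A link function meant to be a cumulative biomarker should match the numerical integral of the regression function, here $\int_0^t A\,(e^{-ku} - e^{-k_a u})\,du = A\left[\frac{1 - e^{-kt}}{k} - \frac{1 - e^{-k_a t}}{k_a}\right]$. The standalone check below compares the two with `torch.cumulative_trapezoid` for two hypothetical individuals; it only uses PyTorch and redefines both functions so it runs on its own.

```python
import torch


def pk_fn(t: torch.Tensor, indiv_params: torch.Tensor) -> torch.Tensor:
    A, k, ka = indiv_params.chunk(3, dim=-1)
    return (A * (torch.exp(-k * t) - torch.exp(-ka * t))).unsqueeze(-1)


def pk_integral_fn(t: torch.Tensor, indiv_params: torch.Tensor) -> torch.Tensor:
    A, k, ka = indiv_params.chunk(3, dim=-1)
    integral = A * ((1 - torch.exp(-k * t)) / k - (1 - torch.exp(-ka * t)) / ka)
    return integral.unsqueeze(-1)


# Two hypothetical individuals with parameters (A, k, ka), in double precision
psi = torch.tensor([[2.0, 0.3, 1.5], [1.0, 0.1, 0.8]], dtype=torch.float64)
t = torch.linspace(0.0, 10.0, 4001, dtype=torch.float64)

# Running trapezoidal integral of h from 0 to each t[1:], shape (2, 4000)
numeric = torch.cumulative_trapezoid(pk_fn(t, psi).squeeze(-1), t, dim=-1)
closed_form = pk_integral_fn(t, psi).squeeze(-1)[:, 1:]
print(torch.max(torch.abs(numeric - closed_form)).item())
```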

### Step 2 — Set initial parameters

```python
from jmstate.functions.base_hazards import Exponential
from jmstate.types import ModelParameters, PrecisionParameters

# Define simple initial parameters
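params = ModelParameters(
    torch.ones(3),  # fixed population effects gamma, one per individual parameter
    PrecisionParameters.from_covariance(torch.eye(3), "diag"),  # random effects (Q)
    PrecisionParameters.from_covariance(torch.eye(1), "spherical"),  # residuals (R)
    {(1, 1): Exponential(1.0), (1, 2): Exponential(1.0)},  # baseline hazards
    # Transition-specific coefficients (alpha and beta in the hazard formula)
    {(1, 1): torch.zeros(1), (1, 2): torch.zeros(1)},
    {(1, 1): torch.zeros(1), (1, 2): torch.zeros(1)},
)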
```
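The second and third arguments build precision parameters from a covariance with a named structure. Assuming the names follow the usual convention (as in scikit-learn's `GaussianMixture`), `"diag"` means one free variance per dimension and `"spherical"` a single variance shared by all dimensions. The plain-PyTorch sketch below illustrates those two structures; the helper names are hypothetical and this is not how `PrecisionParameters` is implemented.

```python
import torch

# Hypothetical helpers, parameterized by log-variances so the variances stay positive


def diag_covariance(log_var: torch.Tensor) -> torch.Tensor:
    # One free variance per dimension: Sigma = diag(exp(log_var))
    return torch.diag(torch.exp(log_var))


def spherical_covariance(log_var: torch.Tensor, dim: int) -> torch.Tensor:
    # A single shared variance: Sigma = exp(log_var) * I
    return torch.exp(log_var) * torch.eye(dim)


Q = diag_covariance(torch.log(torch.tensor([0.5, 1.0, 2.0])))  # 3 free parameters
R = spherical_covariance(torch.log(torch.tensor(0.1)), dim=1)  # 1 free parameter
Q_precision = torch.linalg.inv(Q)  # precision = inverse covariance
print(torch.diagonal(Q_precision))
```

The choice of structure also changes the number of free parameters, which matters for criteria such as AIC and BIC.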

### Step 3 — Prepare data

```python
from jmstate.types import ModelData

data = ModelData(
    x,  # (n, p) covariate matrix
    t,  # (m,) or (n, m) measurement times; NaN-pad if variable
    y,  # (n, m, d) longitudinal observations; NaN-pad if variable
    trajectories,  # list[list[tuple[float, Any]]]
    c,  # (n, 1) right-censoring times
)
```

Each trajectory is a chronologically ordered list of `(time, state)` tuples representing the individual's event history.
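For a dry run before plugging in real data, the inputs can be mocked with the shapes listed above and then passed to `ModelData`. The values, the convention that each trajectory starts with `(0.0, initial_state)`, and the states `1` and `2` (chosen to match the `link_fns` keys of Step 1) are assumptions for illustration.

```python
import torch

torch.manual_seed(0)

n, p = 4, 1  # individuals, covariates

x = torch.randn(n, p)  # (n, p) covariate matrix

# Individual measurement schedules of different lengths, NaN-padded to m = max length
schedules = [[0.5, 1.0, 2.0, 4.0], [0.5, 1.0], [0.5, 1.0, 2.0], [1.0]]
m = max(len(s) for s in schedules)
t = torch.full((n, m), float("nan"))
for i, s in enumerate(schedules):
    t[i, : len(s)] = torch.tensor(s)

# Hypothetical biomarker values: exp(NaN) is NaN, so y inherits the padding of t
y = (2.0 * (torch.exp(-0.3 * t) - torch.exp(-1.5 * t))).unsqueeze(-1)  # (n, m, 1)

# Chronologically ordered (time, state) tuples; (0.8, 1) is a recurrent 1 -> 1 event
trajectories = [
    [(0.0, 1)],
    [(0.0, 1), (1.5, 2)],
    [(0.0, 1), (0.8, 1), (2.5, 2)],
    [(0.0, 1)],
]
c = torch.tensor([[5.0], [3.0], [4.0], [2.0]])  # (n, 1) right-censoring times
```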

### Step 4 — Fit the model

```python
from jmstate import MultiStateJointModel

optimizer = torch.optim.Adam(params.parameters(), lr=0.1)
model = MultiStateJointModel(design, params, optimizer)

metrics = model.fit(data)
```

### Step 5 — Print and plot the results

```python
import matplotlib.pyplot as plt
from jmstate.utils import plot_mcmc_diagnostics, plot_params_history

# Compute and print summary statistics (Wald statistics for the null hypothesis
# that each parameter is zero, p-values, AIC, BIC, etc.)
model.compute_summary().summary()

# Plot parameter history (stochastic optimization)...
plot_params_history(model)
plt.show()

# ...and MCMC sampler diagnostics
plot_mcmc_diagnostics(model)
plt.show()
```
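The quantities named in the summary have standard definitions. The function below is a hypothetical sketch of them, given estimates $\hat\theta$, their covariance (e.g. an inverse Fisher information), the maximized log-likelihood and a sample size. It is not jmstate's `compute_summary`, whose exact conventions (for instance how the number of observations is counted for BIC) may differ.

```python
import math

import torch


def summary_stats(theta_hat: torch.Tensor, cov: torch.Tensor, loglik: float, n_obs: int):
    # Hypothetical helper: standard errors from the diagonal of the covariance
    se = torch.sqrt(torch.diagonal(cov))
    z = theta_hat / se  # Wald statistic for H0: theta_j = 0 (z**2 is its chi-square form)
    normal = torch.distributions.Normal(0.0, 1.0)
    p_values = 2 * (1 - normal.cdf(z.abs()))  # two-sided p-values
    k = theta_hat.numel()  # number of free parameters
    aic = 2 * k - 2 * loglik
    bic = k * math.log(n_obs) - 2 * loglik
    return z, p_values, aic, bic


# Made-up estimates and covariance for three parameters
theta_hat = torch.tensor([0.8, -0.05, 1.2])
cov = torch.diag(torch.tensor([0.04, 0.01, 0.09]))
z, p, aic, bic = summary_stats(theta_hat, cov, loglik=-120.0, n_obs=100)
print(z, p, aic, bic)
```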

---

## License

See [LICENSE](LICENSE).
