Metadata-Version: 2.4
Name: dima
Version: 0.1.2
Summary: DIMA: Diffusion–Intrinsic Manifold Autoencoder
Author: Julio Candanedo
License: MIT
Requires-Python: >=3.10
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: numpy>=1.24
Requires-Dist: scipy>=1.10
Requires-Dist: jax>=0.4.20
Requires-Dist: jaxlib>=0.4.20
Requires-Dist: flax>=0.8
Requires-Dist: optax>=0.2
Requires-Dist: huggingface_hub>=0.22
Requires-Dist: faiss-cpu>=1.7.4
Provides-Extra: viz
Requires-Dist: matplotlib>=3.7; extra == "viz"
Provides-Extra: dev
Requires-Dist: pytest>=7; extra == "dev"
Requires-Dist: ruff>=0.4; extra == "dev"
Requires-Dist: mypy>=1.8; extra == "dev"
Requires-Dist: pre-commit>=3.6; extra == "dev"
Requires-Dist: build>=1.2; extra == "dev"
Requires-Dist: twine>=5; extra == "dev"
Dynamic: license-file

# DIMA — Diffusion–Intrinsic Manifold Autoencoder

**DIMA** combines three components into a practical, scalable manifold autoencoder:

- **DMAP**: Diffusion Maps encoder (builds a kNN graph in ambient space and embeds points into diffusion coordinates).
- **GPLM**: Nyström / inducing-point kernel ridge decoder (maps latent diffusion coordinates back to ambient space).
- **DDPM**: Latent diffusion model (optional) that learns a generative prior over *normalized* latents and can “refine” latents.

> Philosophy: keep **DMAP + GPLM** fast on CPU/RAM (NumPy/SciPy, sparse ops), and optionally run **DDPM** on JAX (CPU or GPU).

---

## Installation

### CPU-only (recommended to start)
```bash
pip install dima
```

### Optional extras

* Faster kNN search (FAISS):

```bash
pip install dima[faiss]
```

* Hugging Face upload/load:

```bash
pip install dima[hf]
```

> **GPU note (JAX):** `pip install dima` installs CPU `jaxlib` by default.
> For GPU, install the correct JAX wheel for your CUDA/ROCm setup first (per JAX docs), then install `dima`.

---

## Quickstart

```python
import numpy as np
import jax
from dima import DIMA

R_iX = np.random.randn(5000, 16).astype(np.float32)

dima = DIMA(R_iX)                 # trains DMAP, GPLM, DDPM with defaults

Z = dima.encode(R_iX[:10])        # ambient -> normalized latent (jnp)
X_hat = dima.decode(Z, refine=False)  # latent -> ambient (np), no DDPM
X_ref = dima.decode(Z, refine=True, t_start=10)  # DDPM refinement

X_gen = dima.sample(1000)         # unconditional samples -> ambient (np)
```

---

## What DIMA trains

### 1) DMAP encoder (Diffusion Maps)

DMAP builds a kNN graph over ambient data $R_{iX}\in\mathbb{R}^{N\times D}$ and returns diffusion coordinates
$R_{ix}\in\mathbb{R}^{N\times d}$.

Key idea (no theory): DMAP produces a **geometry-aware** latent space where nearby points on the manifold remain nearby
in diffusion distance.

**Main knobs (DMAP):**

* `d` *(int)*: latent dimension.
* `k` *(int)*: neighbors in the kNN graph. If `None`, a heuristic is used.
* `beta` / `β` *(float)*: kernel sharpness for the RBF affinity
  $K_{ij}=\exp{-\beta |x_i-x_j|^2/\varepsilon}$.
* `eps` / `ε` *(float or None)*: kernel bandwidth. If `None`, estimated from kNN distances (median heuristic).
* `alpha` / `α` *(float)*: density normalization exponent. Common values: `0.0` (none) or `1.0` (often robust).
* `t` *(float)*: diffusion time exponent (scales eigenvalues as $\lambda^t$). Often `0.5` or `1.0`.
* `drop_trivial` *(bool)*: drops the top eigenvector/eigenvalue (the constant mode).
* `sym` *(str)*: symmetrization mode for sparse kNN kernel graph: `"max"` or `"mean"`.
* `ann_backend` *(str)*: `"auto" | "faiss" | "pynndescent" | "sklearn" | "brute"`.

**Typical DMAP presets**

* Fast-ish and stable: `k=128..512`, `alpha=1.0`, `t=0.5`, `drop_trivial=True`.
* If your data is very noisy, try larger `k` and/or larger `eps_mul`.

---

### 2) GPLM decoder (Nyström kernel ridge / inducing GP)

GPLM learns a mapping from latents back to ambient:

* Inputs: latent training points $R_{ix}\in\mathbb{R}^{N\times d}$
* Targets: ambient training points $R_{iX}\in\mathbb{R}^{N\times D}$

Instead of a full $N\times N$ kernel solve, GPLM uses **inducing points** $Z_{mx}$ with $m\ll N$ and solves a reduced system:

* Build affinities $C_{im}=\exp{-\beta|R_{ix}-Z_{mx}|^2/\varepsilon}$
* Solve a stabilized kernel ridge / GP mean system to obtain weights $M_{mX}$
* Predict: $\hat R_{aX}=C_{am}M_{mX}+\mu_X$

**Main knobs (GPLM):**

* `m` *(int)*: number of inducing points. Bigger → better accuracy, more compute.
* `inducing` *(str)*: inducing strategy:

  * `"kmeans_medoids"` (default): kmeans centers snapped to nearest training latent.
  * `"fps"`: farthest-point sampling (space-filling).
  * `"random_subset"`: fastest.
  * `"given"`: use provided `Z_mx`.
* `sigma2` / `σ2` *(float)*: ridge regularization. Too small can overfit / cause instability; too large blurs reconstructions.
* `jitter` *(float)*: tiny diagonal stabilizer for Cholesky.
* `eps` / `ε` *(float or None)*: RBF bandwidth in latent space. If `None`, estimated from latent kNN distances.
* `k_eps` / `κ_eps` *(int)*: neighbors used for the $\varepsilon$ heuristic.
* `pred_k` / `pred_κ` *(int or None)*: prediction-time inducing neighbors:

  * `None` means use all inducing points (best accuracy).
  * a small number (e.g. `128` or `256`) speeds inference (slightly lower accuracy).
* `whiten_latent` *(bool)*: optionally standardize latent dimensions before kernel computation.
* `center_X` *(bool)*: subtract and re-add ambient mean (usually helpful).
* `fit_block` *(int)*: block size for streaming $C^T C$ accumulation (memory/perf knob).
* `ann_backend` *(str)*: same options as DMAP.

**Typical GPLM presets**

* Accurate: `m=1024..4096`, `pred_k=None`, `sigma2=1e-5` (tune).
* Faster inference: set `pred_k=128..512`.

---

### 3) DDPM latent diffusion (optional generative prior)

DDPM learns a distribution over **normalized** latents:
[
Z = \frac{R_x - \mu}{\sigma}
]
and can:

* **sample** new latents,
* **refine** a given latent by projecting it onto the learned latent manifold/prior.

In DIMA, DDPM operates purely in latent space (dimension `d`), so it’s lightweight compared to image DDPMs.

**Main knobs (DDPM):**

* `T` *(int)*: number of diffusion steps. Common: `100..1000`. DIMA default: `200`.
* `hidden_dim` *(int)*: MLP width.
* `t_embed_dim` *(int)*: time embedding dimension.
* `n_iter` *(int)*: training iterations (more is better for sampling quality).
* `batch_size` *(int)*: training batch size.
* `learning_rate` *(float)*: Adam learning rate.
* `ema_decay` *(float)*: EMA smoothing for stable sampling (typical: `0.999`).
* `beta_max` *(float)*: caps noise schedule (stability knob).
* `eps` *(float)*: numerical stabilizer.
* `verbose_every` *(int)*: progress prints.

**Refinement knobs (during decoding):**

* `refine` *(bool)*: enable/disable DDPM refinement.
* `t_start` *(int)*: how strongly to “project” using reverse diffusion.

  * small (`1..10`) = gentle projection
  * larger (`20..100`) = stronger projection (can oversmooth or drift if DDPM undertrained)
* `add_noise` *(bool)*: whether to forward-noise before reverse steps.

---

## API

### Core class

```python
from dima import DIMA
dima = DIMA(R_iX)
```

### Encode / decode

```python
Z = dima.encode(X)                      # (B,d) normalized latent (jnp)
X_hat = dima.decode(Z, refine=False)    # (B,D) reconstruction (np)
X_ref = dima.decode(Z, refine=True, t_start=10)  # refined decode
```

### Polymorphic call

```python
Z = dima(X)     # if X.shape[-1] == D
X = dima(Z)     # if Z.shape[-1] == d
```

### Sampling

```python
X_gen = dima.sample(1000)            # ambient samples (np)
Z_gen = dima.sample(1000, decode=False)  # latent samples (jnp)
```

---

## Configuration patterns

### 1) Use defaults (cleanest)

```python
dima = DIMA(R_iX)
```

### 2) Pass only a few knobs

```python
dima = DIMA(
    R_iX,
    d=64,
    dmap_kwargs=dict(k=256, alpha=1.0, t=0.5),
    gplm_kwargs=dict(m=2048, sigma2=1e-5, pred_k=256),
    ddpm_kwargs=dict(n_iter=50_000, hidden_dim=256),
)
```

---

## Saving / Loading

### Local

```python
dima.save_local("dima.msgpack", "config.json")
dima2 = DIMA.load_local("dima.msgpack", ddpm_device="auto")
```

### Hugging Face (optional)

```python
dima.upload_to_huggingface(repo_id="username/dima-model", hf_token="...")

dima3 = DIMA.load_from_huggingface("username/dima-model", ddpm_device="auto")
```

---

## Performance notes

* **DMAP**: main costs are kNN search + sparse eigensolve.

  * kNN is faster with FAISS.
  * Increasing `k` increases graph density and compute.
* **GPLM**: training cost roughly scales with $N\cdot m$ (streamed by `fit_block`).

  * Inference cost is $B\cdot m$ if `pred_k=None`, or $B\cdot \text{pred_k}$ if using inducing kNN.
* **DDPM**: scales with latent dimension `d`, steps `T`, and training iterations.

  * GPU helps but CPU works for smaller `d` and fewer iterations.

---

## Troubleshooting

### “Unknown backend: gpu”

Your JAX install only has CPU. Use:

```python
dima = DIMA(R_iX, ddpm_device="cpu")
```

or install a GPU-enabled JAX build.

### Reconstructions are blurry / low quality

* Increase `gplm_kwargs["m"]`
* Decrease `gplm_kwargs["sigma2"]` slightly (careful: too small can destabilize)
* Set `pred_k=None` (use all inducing points)

### DDPM refinement makes reconstructions worse

* Reduce `t_start` (try `3..10`)
* Increase DDPM training (`ddpm_kwargs["n_iter"]`)
* Disable `add_noise` for gentler behavior

---

## Citations

* Diffusion Maps / DMAE inspiration: *Diffusion Map AutoEncoder (DMAE)*.
* Latent diffusion prior: *Denoising Diffusion Probabilistic Models (DDPM)*.

---

## License

MIT
