Metadata-Version: 2.4
Name: entropy-profiler
Version: 0.2.1
Summary: Extract, analyze, and visualize entropy profiles from transformer models using the logit-lens technique.
Project-URL: Homepage, https://github.com/TheGitCommit/entropy-profiler
Project-URL: Documentation, https://github.com/TheGitCommit/entropy-profiler/blob/master/README.md
Project-URL: Repository, https://github.com/TheGitCommit/entropy-profiler
Project-URL: Issues, https://github.com/TheGitCommit/entropy-profiler/issues
Author: entropy-profiler contributors
License-Expression: MIT
License-File: LICENSE
Keywords: entropy,interpretability,logit-lens,machine-learning,nlp,profiling,transformer
Classifier: Development Status :: 4 - Beta
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: MIT License
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Programming Language :: Python :: 3.13
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Classifier: Typing :: Typed
Requires-Python: >=3.10
Requires-Dist: matplotlib>=3.7.0
Requires-Dist: numpy>=1.24.0
Requires-Dist: scikit-learn>=1.3.0
Requires-Dist: scipy>=1.11.0
Requires-Dist: seaborn>=0.12.0
Requires-Dist: torch>=2.1.0
Requires-Dist: transformers>=4.40.0
Provides-Extra: dev
Requires-Dist: pytest>=7.0.0; extra == 'dev'
Requires-Dist: ruff>=0.4.0; extra == 'dev'
Provides-Extra: notebook
Requires-Dist: datasets>=3.0.0; extra == 'notebook'
Requires-Dist: ipywidgets>=8.0.0; extra == 'notebook'
Requires-Dist: jupyter>=1.0.0; extra == 'notebook'
Requires-Dist: notebook>=7.0.0; extra == 'notebook'
Provides-Extra: quantize
Requires-Dist: accelerate>=0.25.0; extra == 'quantize'
Requires-Dist: bitsandbytes>=0.41.0; extra == 'quantize'
Description-Content-Type: text/markdown

# entropy-profiler

Extract, analyze, and visualize entropy profiles from transformer models using
the logit-lens technique.

`entropy-profiler` computes per-layer Shannon or Rényi entropy by passing
hidden states through the model's own unembedding head (layer norm + lm_head).
It works on any HuggingFace `CausalLM` without architecture-specific hooks.

```python
from entropy_profiler import EntropyProfiler, plot_profile
import torch

profiler = EntropyProfiler("gpt2", dtype=torch.float32)
profile = profiler.profile_text("The meaning of life is", max_new_tokens=32)
plot_profile(profile)
```

---

## Installation

### From source (recommended for development)

```bash
git clone https://github.com/TheGitCommit/entropy-profiler
cd entropy-profiler

# Using uv (fast, handles venvs automatically)
uv sync                      # core dependencies
uv sync --extra notebook     # + Jupyter support
uv sync --extra dev          # + pytest, ruff

# Or using pip
pip install -e .
pip install -e ".[quantize]"   # + 8-bit/4-bit quantization (bitsandbytes, accelerate)
pip install -e ".[notebook]"
pip install -e ".[dev]"
```

### From PyPI (once published)

```bash
pip install entropy-profiler
```

---

## Quick Start

### Profile a single prompt

```python
from entropy_profiler import EntropyProfiler, plot_profile
import torch

profiler = EntropyProfiler("gpt2", dtype=torch.float32)
profile = profiler.profile_text("The capital of France is", max_new_tokens=32)

print(profile.entropy.shape)    # (n_tokens, n_layers)
print(profile.mean_profile())   # (n_layers,) tensor
plot_profile(profile)
```

### Profile multiple prompts

```python
from entropy_profiler import plot_aggregated

agg = profiler.profile_batch([
    "The stock market experienced significant",
    "In quantum mechanics, the wave function",
    "Modern neural networks learn by",
], max_new_tokens=24)

print(agg.to_matrix().shape)    # (3, n_layers)
plot_aggregated(agg)
```

### Compare prompts with distances

```python
from entropy_profiler import profile_distance

p1 = profiler.profile_text("Water boils at", max_new_tokens=24)
p2 = profiler.profile_text("Once upon a time", max_new_tokens=24)

result = profile_distance(p1, p2, metric="jsd")
print(f"JSD distance: {result.aggregate:.4f}")
```

### Analyze layer dynamics

```python
from entropy_profiler import LayerAnalyzer

profile, hidden_states = profiler.profile_text_with_states(
    "Hello world", max_new_tokens=32
)
analyzer = LayerAnalyzer(profiler, profile, hidden_states=hidden_states)

print(analyzer.layer_entropy())          # (n_layers,)
print(analyzer.information_velocity())   # (n_layers,)
print(analyzer.layer_mi(method="cka"))   # (n_layers,)
```

---

## Core Concepts

### Logit-Lens Decoding

At each transformer layer, the hidden state is projected through the model's
final layer norm and language model head to produce a vocabulary distribution.
The entropy of this distribution measures how "decided" the model is at that
layer — low entropy means a peaked distribution (confident prediction), high
entropy means a flat distribution (uncertain).
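The decoding step can be illustrated with a toy, stdlib-only sketch. It is not the library's implementation: the helper names, the 3-token vocabulary, the hand-picked weights, and the layer norm without learned scale or shift are all assumptions made for illustration.

```python
import math

def layer_norm(x, eps=1e-5):
    """Normalize a vector to zero mean and unit variance (no learned scale/shift)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]

def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def entropy_nats(probs):
    """Shannon entropy H(p) = -sum(p log p), skipping zero-probability entries."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)

def logit_lens_entropy(hidden, unembed):
    """Entropy of the vocabulary distribution decoded from one hidden state.

    hidden:  list of d floats (a hidden state at some layer)
    unembed: one row of d floats per vocabulary token (toy lm_head weights)
    """
    normed = layer_norm(hidden)
    logits = [sum(w * h for w, h in zip(row, normed)) for row in unembed]
    return entropy_nats(softmax(logits))

# Hypothetical toy head: hidden size 4, vocabulary of 3 tokens.
TOY_UNEMBED = [
    [5.0, -5.0, 0.0, 0.0],
    [0.0, 0.0, 5.0, -5.0],
    [-5.0, 5.0, 0.0, 0.0],
]

decided = logit_lens_entropy([1.0, -1.0, 0.0, 0.0], TOY_UNEMBED)    # aligned with token 0
undecided = logit_lens_entropy([1.0, 1.0, -1.0, -1.0], TOY_UNEMBED)  # all logits equal

print(f"decided:   {decided:.6f} nats")
print(f"undecided: {undecided:.6f} nats (max is log(3) = {math.log(3):.6f})")
```

The hidden state aligned with one unembedding row decodes to a sharply peaked distribution with entropy close to zero, while the state that produces equal logits for every token reaches the maximum of `log(vocab_size)`.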

### Entropy Profiles

An **entropy profile** is a matrix of shape `(n_tokens, n_layers)` where each
entry is the entropy of the vocabulary distribution at that token position and
layer depth. The **mean profile** `(n_layers,)` averages across tokens to give
a single curve showing how entropy evolves through the network.
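As a small illustration of the shape convention, the sketch below averages a made-up `(n_tokens, n_layers)` matrix over tokens. It uses plain nested lists rather than tensors, and the values are invented for the example.

```python
def mean_profile(entropy):
    """Average a (n_tokens, n_layers) nested list over tokens.

    Returns a list of n_layers floats, one mean entropy per layer.
    """
    n_tokens = len(entropy)
    n_layers = len(entropy[0])
    return [sum(row[j] for row in entropy) / n_tokens for j in range(n_layers)]

# Made-up values: 3 tokens (rows) x 4 layers (columns), entropy in nats.
toy_entropy = [
    [8.0, 6.0, 3.0, 1.0],
    [9.0, 7.0, 5.0, 2.0],
    [10.0, 8.0, 4.0, 3.0],
]

curve = mean_profile(toy_entropy)
print(curve)  # [9.0, 7.0, 4.0, 2.0]
```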

### Why Rényi Entropy?

Shannon entropy (`alpha=1`) is the default, but Rényi entropy at other orders
provides complementary views:
- `alpha < 1`: emphasizes rare events (tail sensitivity)
- `alpha = 1`: Shannon entropy (standard)
- `alpha = 2`: collision entropy (weighted toward high-probability tokens)
- `alpha > 2`: increasingly dominated by the most probable token
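For reference, Rényi entropy of order α is `H_α(p) = log(sum(p_i^α)) / (1 - α)`, with Shannon entropy as the limit α → 1. The stdlib-only sketch below shows how the order changes the value on a peaked versus a uniform distribution. The function name mirrors the library's `renyi_entropy`, but the body, the `tol` parameter, and the example distributions are assumptions for illustration, not the library's code.

```python
import math

def renyi_entropy(probs, alpha, tol=1e-6):
    """Rényi entropy of order alpha in nats; falls back to Shannon near alpha = 1."""
    if alpha < 0:
        raise ValueError("alpha must be non-negative")
    support = [p for p in probs if p > 0.0]  # zero-probability entries contribute nothing
    if abs(alpha - 1.0) < tol:
        return -sum(p * math.log(p) for p in support)
    return math.log(sum(p ** alpha for p in support)) / (1.0 - alpha)

peaked = [0.90, 0.05, 0.03, 0.02]
uniform = [0.25, 0.25, 0.25, 0.25]
orders = [0.5, 1.0, 2.0, 10.0]

peaked_values = [renyi_entropy(peaked, a) for a in orders]
uniform_values = [renyi_entropy(uniform, a) for a in orders]

for a, hp, hu in zip(orders, peaked_values, uniform_values):
    print(f"alpha={a:>4}: peaked={hp:.4f}  uniform={hu:.4f}")
```

On the uniform distribution every order gives `log(4)`. On the peaked distribution the value decreases as α grows, because higher orders weight the most probable token more heavily.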

---

## API Reference

### Core Module (`entropy_profiler.profiler`)

| Symbol | Description |
|--------|-------------|
| `EntropyProfiler(model, dtype, alpha, layer_stride, load_in_8bit, load_in_4bit)` | Main class. Loads model, runs generation, computes entropy. Use `load_in_8bit` or `load_in_4bit` for quantized loading (requires bitsandbytes). |
| `EntropyProfile` | Dataclass: `entropy`, `token_ids`, `layer_indices`, `alpha`, `model_name`, `metadata`. |
| `AggregatedProfile` | Collection of profiles with `mean_profile()` and `to_matrix()`. |
| `shannon_entropy(probs)` | Shannon entropy `H(p) = -sum(p log p)`, computed over the last dimension. |
| `renyi_entropy(probs, alpha)` | Rényi entropy of order α. Falls back to Shannon when α ≈ 1. |

**`EntropyProfiler` methods:**

| Method | Returns | Description |
|--------|---------|-------------|
| `profile_text(prompt, max_new_tokens, ...)` | `EntropyProfile` | Profile generated text. |
| `profile_text_with_states(prompt, ...)` | `(EntropyProfile, Tensor)` | Profile + raw hidden states. |
| `profile_batch(prompts, ...)` | `AggregatedProfile` | Profile multiple prompts. |
| `unload()` | `None` | Free model memory. |

**`EntropyProfile` attributes and methods:**

| Member | Type | Description |
|--------|------|-------------|
| `entropy` | `Tensor (n_tokens, n_layers)` | Per-token, per-layer entropy. |
| `token_ids` | `Tensor (n_tokens,)` | Generated token IDs. |
| `n_layers` | `int` | Number of profiled layers. |
| `n_tokens` | `int` | Number of profiled tokens. |
| `mean_profile()` | `Tensor (n_layers,)` | Mean entropy at each layer. |
| `to_numpy()` | `ndarray` | Convert to NumPy (float32). |

### Distances (`entropy_profiler.distances`)

| Function | Returns | Description |
|----------|------|-------------|
| `profile_distance(p1, p2, metric, aggregation)` | `DistanceResult` | Unified entry point. |
| `pairwise_distances(profiles, metric)` | `ndarray (N, N)` | Symmetric distance matrix. |
| `jsd_layer(p1, p2, n_bins)` | `ndarray (n_layers,)` | Per-layer Jensen-Shannon divergence. |
| `wasserstein_layer(p1, p2)` | `ndarray (n_layers,)` | Per-layer Wasserstein-1 distance. |
| `fisher_rao_distance(p1, p2)` | `float` | Geodesic on probability simplex. |
| `srvf_distance(p1, p2)` | `float` | Elastic SRVF curve distance. |

**Available metrics for `profile_distance`:** `"jsd"`, `"wasserstein"`, `"fisher_rao"`, `"srvf"`.

**Aggregation methods:** `"mean"`, `"max"`, `"sum"` (for layer-wise metrics).
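To make two of these metrics concrete, here is a stdlib-only sketch of the textbook definitions of Jensen-Shannon divergence and Fisher-Rao distance between two discrete distributions. How the library turns entropy profiles into distributions (for example the binning behind `n_bins`) is not shown here, so the function names and inputs below are illustrative assumptions.

```python
import math

def kl_divergence(p, q):
    """KL(p || q) in nats; terms with p_i == 0 contribute zero."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0.0)

def jensen_shannon_divergence(p, q):
    """JSD(p, q) = 0.5 * KL(p || m) + 0.5 * KL(q || m), with m the midpoint."""
    m = [(pi + qi) / 2.0 for pi, qi in zip(p, q)]
    return 0.5 * kl_divergence(p, m) + 0.5 * kl_divergence(q, m)

def fisher_rao(p, q):
    """Geodesic distance on the simplex: 2 * arccos(sum(sqrt(p_i * q_i)))."""
    bc = sum(math.sqrt(pi * qi) for pi, qi in zip(p, q))
    # Clamp to [0, 1] so rounding error cannot push acos out of its domain.
    return 2.0 * math.acos(min(1.0, max(0.0, bc)))

a = [0.7, 0.2, 0.1]
b = [0.1, 0.3, 0.6]

print(f"JSD(a, b)        = {jensen_shannon_divergence(a, b):.4f}")
print(f"Fisher-Rao(a, b) = {fisher_rao(a, b):.4f}")
```

Both quantities are symmetric and bounded: JSD in nats lies in `[0, log 2]`, and the Fisher-Rao distance lies in `[0, pi]`, with the upper bounds reached for distributions with disjoint support.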

### Layer Analysis (`entropy_profiler.analysis`)

| Symbol | Description |
|--------|-------------|
| `LayerAnalyzer(profiler, profile, hidden_states)` | Per-layer metric computation. |

Additional functions available via `from entropy_profiler.analysis import ...`:
`compare_models`, `plot_layer_importance`, `plot_information_plane`, `plot_velocity_entropy`.

**`LayerAnalyzer` methods:**

| Method | Returns | Description |
|--------|---------|-------------|
| `layer_entropy()` | `ndarray (n_layers,)` | Mean Shannon entropy per layer. |
| `information_velocity()` | `ndarray (n_layers,)` | Wasserstein distance between consecutive layers. |
| `distance_to_output()` | `ndarray (n_layers,)` | Fisher-Rao distance to final layer. |
| `jsd_to_output(n_bins)` | `ndarray (n_layers,)` | JSD from each layer to final. |
| `layer_mi(method)` | `ndarray (n_layers,)` | Mutual information with the final layer (Rényi or CKA). |
| `layer_importance()` | `dict` | All four non-MI metrics. |
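The idea behind a Wasserstein-based velocity can be sketched with the standard library alone. For two equal-size samples, the 1D Wasserstein-1 distance is the mean absolute difference of the sorted values. The helper names and toy matrix below are assumptions, and this sketch returns `n_layers - 1` values; how the library fills its `(n_layers,)` output is not specified here.

```python
def wasserstein_1d(xs, ys):
    """Wasserstein-1 distance between two equal-size empirical samples."""
    if len(xs) != len(ys):
        raise ValueError("this sketch only handles equal-size samples")
    return sum(abs(a - b) for a, b in zip(sorted(xs), sorted(ys))) / len(xs)

def layer_velocity(entropy):
    """Distance between per-token entropy samples of consecutive layers.

    entropy: (n_tokens, n_layers) nested list. Returns n_layers - 1 floats.
    """
    columns = list(zip(*entropy))  # one tuple of per-token entropies per layer
    return [wasserstein_1d(columns[j], columns[j + 1]) for j in range(len(columns) - 1)]

# Made-up values: 3 tokens x 4 layers, entropy in nats.
toy_entropy = [
    [8.0, 6.0, 3.0, 1.0],
    [9.0, 7.0, 5.0, 2.0],
    [10.0, 8.0, 4.0, 3.0],
]

velocity = layer_velocity(toy_entropy)
print(velocity)  # [2.0, 3.0, 2.0]
```

A large value between two layers means the distribution of per-token entropies shifts sharply there. In the toy data the biggest shift is between the second and third layers.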

### Visualization (`entropy_profiler.viz`)

| Function | Description |
|----------|-------------|
| `plot_profile(profile, ax, ...)` | Line plot with ±1 std fill. |
| `plot_profiles(profiles, labels, ...)` | Overlay multiple profiles. |
| `plot_heatmap(profile, ax, ...)` | Token × layer entropy heatmap. |
| `plot_aggregated(agg, ax, ...)` | Aggregated mean ± std curve. |
| `plot_cluster(profiles, labels, method, feature, metric, ...)` | 2D scatter via t-SNE/UMAP/PCA. |

### Estimators (`entropy_profiler.estimators`)

| Symbol | Description |
|--------|-------------|
| `MatrixRenyiMI(alpha, device)` | Matrix-based Rényi MI via Gram matrices. Used by `LayerAnalyzer.layer_mi()`. |

---

## Supported Models

Any HuggingFace `AutoModelForCausalLM` is supported. The profiler automatically
detects the unembedding architecture:

| Model Family | Layer Norm Path | Status |
|-------------|-----------------|--------|
| GPT-2 | `transformer.ln_f` | Tested |
| LLaMA / LLaMA 2 / LLaMA 3 | `model.norm` | Tested |
| Mistral | `model.norm` | Tested |
| Gemma / Gemma 2 | `model.norm` | Tested |
| Qwen / Qwen 2 | `model.norm` | Tested |
| OPT | `model.norm` (fallback) | Tested |

To add a new architecture, add a resolution pattern to
`_get_unembedding()` in `entropy_profiler/profiler.py`.

### Tips for large models

```python
import torch
from entropy_profiler import EntropyProfiler

# Use float16 for 7B+ models to fit in GPU memory
profiler = EntropyProfiler("meta-llama/Llama-2-7b-hf", dtype=torch.float16)

# Load in 8-bit or 4-bit to fit even larger models (requires: pip install bitsandbytes accelerate)
profiler = EntropyProfiler("meta-llama/Llama-2-7b-hf", load_in_8bit=True)
profiler = EntropyProfiler("meta-llama/Llama-2-7b-hf", load_in_4bit=True)

# Profile every other layer to reduce computation
profiler = EntropyProfiler("gpt2", layer_stride=2)

# Use context manager to auto-unload
with EntropyProfiler("gpt2") as profiler:
    profile = profiler.profile_text("Hello world")
```

---

## Design Decisions

**No hooks.** HuggingFace's `output_hidden_states=True` returns all hidden
states without hook infrastructure, so hidden-state extraction needs no
architecture-specific code. Only the unembedding lookup in
`_get_unembedding()` depends on the architecture.

**Logit-lens, not probing.** The unembedding head is the model's own decoder.
No training of linear probes is needed — the entropy values are directly
interpretable as "how peaked is the vocabulary distribution at this layer."

**Float32 entropy.** Entropy is always computed in float32 regardless of model
dtype, avoiding numerical issues with half-precision softmax.

**Dataclass outputs.** `EntropyProfile` and `DistanceResult` are plain
dataclasses — easy to inspect, serialize, and compose.

---

## Notebooks

| Notebook | Description |
|----------|-------------|
| `exploration.ipynb` | Multi-model exploration: entropy curves, heatmaps, velocities, distances, clustering. Requires GPU and gated-model access. |
| `api_tour.ipynb` | Complete API tour exercising every public function with GPT-2. |

```bash
uv sync --extra notebook
uv run jupyter notebook notebooks/
```

---

## Development

```bash
# Install dev dependencies
uv sync --extra dev

# Lint
uv run ruff check .
uv run ruff check --fix .

# Test
uv run pytest

# Run a script without activating venv
uv run python your_script.py
```

---

## License

MIT
