Metadata-Version: 2.4
Name: reverse-attention
Version: 0.1.0
Summary: Reverse Attention Beam Search for tracing attention paths in transformer models
Author: Abhishek Maiti
License: MIT
Project-URL: Homepage, https://github.com/abhishekmaiti/reverse-attention
Project-URL: Repository, https://github.com/abhishekmaiti/reverse-attention
Keywords: attention,transformer,visualization,interpretability,qwen2
Classifier: Development Status :: 3 - Alpha
Classifier: Intended Audience :: Developers
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: MIT License
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Requires-Python: >=3.9
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: torch>=2.0.0
Requires-Dist: transformers>=4.35.0
Requires-Dist: numpy>=1.24.0
Provides-Extra: dev
Requires-Dist: pytest>=7.0.0; extra == "dev"
Requires-Dist: pytest-cov; extra == "dev"
Provides-Extra: docs
Requires-Dist: mkdocs>=1.5.0; extra == "docs"
Requires-Dist: mkdocs-material>=9.5.0; extra == "docs"
Requires-Dist: mkdocstrings[python]>=0.24.0; extra == "docs"
Dynamic: license-file

# Reverse Attention Tracer (RAT)

A Python package for tracing attention paths backward through transformer models, with interactive D3.js Sankey visualization.

## Features

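- **Reverse attention tracing**: Find the earlier tokens that a target token attends to most strongly, directly or through intermediate tokens, by following attention paths backward
- **Beam search**: Explore several high-scoring paths through the attention matrix at once, keeping the best `top_beam` candidates at each step rather than only the single strongest edge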
- **Interactive visualization**: D3.js-powered Sankey diagrams with zoom, pan, and click-to-highlight
- **Qwen2 support**: Targets Qwen2 family models; other HuggingFace transformers that can return attention weights are expected to work as well

## Installation

```bash
pip install reverse-attention
```

Or install from source:

```bash
git clone https://github.com/ovshake/reverse-attention
cd reverse-attention
pip install -e .
```

## Quick Start

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from reverse_attention import ReverseAttentionTracer

# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B")

# Create tracer
tracer = ReverseAttentionTracer(model, tokenizer)

# Trace attention from the last token
result = tracer.trace_text("The quick brown fox jumps over the lazy dog.")

# Print top attention paths
for i, path in enumerate(result.paths_text):
    print(f"Beam {i+1}: {path}")

# Generate interactive visualization
tracer.render_html(result, "output/", open_browser=True)
```

## API Reference

### ReverseAttentionTracer

The main class for tracing attention paths.

```python
tracer = ReverseAttentionTracer(model, tokenizer, device=None, dtype=None)
```

#### Parameters

- `model`: HuggingFace transformer model
- `tokenizer`: Corresponding tokenizer
- `device`: Device to run on (defaults to model's device)
- `dtype`: Data type for computation (defaults to model's dtype)

### trace()

Trace attention paths backward from a target position.

```python
result = tracer.trace(
    input_ids,              # Input token IDs [1, seq_len]
    target_pos=-1,          # Position to trace from (supports negative indexing)
    attention_mask=None,    # Optional attention mask
    layer=-1,               # Layer index (supports negative indexing)
    top_beam=5,             # Number of beams to keep
    top_k=5,                # Top-k predecessors per step
    min_attn=0.0,           # Minimum attention threshold
    agg_heads="mean",       # Head aggregation: "mean", "max", "none"
    length_norm="avg_logprob",  # Score normalization
    stop_at_bos=True,       # Stop at BOS tokens
    bos_token_id=None,      # Override BOS token ID
)
```
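To make the roles of `top_beam`, `top_k`, and `min_attn` concrete, here is a minimal pure-Python sketch of reverse beam search over a single attention matrix whose heads have already been aggregated. It is not the package's implementation. The function name `reverse_beam_search`, the layout `attn[q][k]` (attention from query position `q` to key position `k`), the restriction to strictly earlier positions, the strict `> min_attn` cutoff, and ranking finished paths by average log attention per edge are all illustrative assumptions. Paths end at position 0, which is where a BOS token would sit if present.

```python
import math
from typing import List, Tuple

Path = Tuple[List[int], float, float]  # (positions, score_raw, score_norm)


def reverse_beam_search(
    attn: List[List[float]],  # attn[q][k]: attention from query q to key k (assumed layout)
    target_pos: int = -1,
    top_beam: int = 5,
    top_k: int = 5,
    min_attn: float = 0.0,
) -> List[Path]:
    seq_len = len(attn)
    # Resolve negative indexing, e.g. -1 -> seq_len - 1
    start = target_pos + seq_len if target_pos < 0 else target_pos

    beams = [([start], 0.0)]  # active beams: (positions, cumulative log attention)
    finished = []
    while beams:
        candidates = []
        for positions, score in beams:
            cur = positions[-1]
            # Only strictly earlier positions, so every path moves backward and terminates.
            # Edges at or below min_attn are dropped, which also avoids log(0).
            preds = [(k, attn[cur][k]) for k in range(cur) if attn[cur][k] > min_attn]
            if not preds:
                # Reached position 0 or no edge passed the threshold
                finished.append((positions, score))
                continue
            preds.sort(key=lambda p: p[1], reverse=True)
            for k, a in preds[:top_k]:  # expand the top_k strongest predecessors
                candidates.append((positions + [k], score + math.log(a)))
        # All active candidates have the same length here, so raw scores are comparable
        candidates.sort(key=lambda b: b[1], reverse=True)
        beams = candidates[:top_beam]

    # Finished paths differ in length, so rank them by average log attention per edge
    results = [(pos, raw, raw / max(1, len(pos) - 1)) for pos, raw in finished]
    results.sort(key=lambda r: r[2], reverse=True)
    return results[:top_beam]


# Usage: a toy 4-token causal attention matrix (each row sums to 1)
attn = [
    [1.0, 0.0, 0.0, 0.0],
    [0.9, 0.1, 0.0, 0.0],
    [0.2, 0.7, 0.1, 0.0],
    [0.1, 0.2, 0.6, 0.1],
]
paths = reverse_beam_search(attn, top_beam=3, top_k=2)
for positions, raw, norm in paths:
    print(positions, round(raw, 3), round(norm, 3))
```

On this toy matrix the best path follows the strongest edge at every step, `[3, 2, 1, 0]`; raising `min_attn` to 0.5 prunes every other branch and leaves only that path.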

### trace_text()

Convenience wrapper that tokenizes the text with the tracer's tokenizer and then calls `trace()`.

```python
result = tracer.trace_text(
    "Your text here",
    target_pos=-1,
    top_beam=5,  # any other trace() keyword argument is accepted as well
)
```

### render_html()

Generate interactive HTML visualization.

```python
html_path = tracer.render_html(
    result,                 # TraceResult from trace()
    out_dir="output/",      # Output directory
    open_browser=False,     # Open in browser after generation
)
```

### TraceResult

The result object returned by `trace()` and `trace_text()`:

- `seq_len`: Sequence length
- `target_pos`: Target position (resolved to positive index)
- `layer`: Layer index (resolved to positive index)
- `top_beam`: Number of beams used
- `top_k`: Top-k value used
- `tokens`: List of all tokens in sequence
- `beams`: List of `BeamPath` objects
- `sankey`: `SankeyData` for visualization
- `paths_text`: Human-readable description of each beam path, one string per beam

### BeamPath

A single attention path:

- `positions`: Token positions in sequence
- `tokens`: Token strings
- `token_ids`: Token IDs
- `edge_attns`: Attention weights along edges
- `score_raw`: Raw cumulative log score
- `score_norm`: Length-normalized score

## Score Normalization

The `length_norm` parameter controls how path scores are normalized:

- `"none"`: No normalization (`score_raw`, the sum of log attention weights along the path)
- `"avg_logprob"`: Divide by path length, giving the log of the geometric mean of the edge attention weights (default)
- `"sqrt"`: Divide by the square root of the path length
- `"pow:α"`: Divide by path_length^α (e.g., `"pow:0.7"`)
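Because every edge adds a non-positive log term, raw scores favour short paths, and normalization counteracts that. The sketch below shows one way the four modes could be computed from a path's edge attention weights (the `edge_attns` field of `BeamPath`). The helper name `score_path` and the choice to measure path length in edges are assumptions, not the package's code.

```python
import math
from typing import Sequence


def score_path(edge_attns: Sequence[float], length_norm: str = "avg_logprob") -> float:
    """Score one path from the attention weight on each of its edges."""
    raw = sum(math.log(a) for a in edge_attns)  # cumulative log score
    n = max(1, len(edge_attns))  # assumption: path length = number of edges
    if length_norm == "none":
        return raw
    if length_norm == "avg_logprob":
        return raw / n  # log of the geometric mean of the edge weights
    if length_norm == "sqrt":
        return raw / math.sqrt(n)
    if length_norm.startswith("pow:"):
        alpha = float(length_norm.split(":", 1)[1])  # e.g. "pow:0.7" -> 0.7
        return raw / n ** alpha
    raise ValueError(f"unknown length_norm: {length_norm!r}")


# Usage: a path with two edges weighted 0.25 and 1.0
edges = [0.25, 1.0]
for mode in ("none", "avg_logprob", "sqrt", "pow:0.7"):
    print(mode, score_path(edges, mode))
print(math.exp(score_path(edges, "avg_logprob")))  # geometric mean, sqrt(0.25 * 1.0)
```

With α between 0 and 1, `pow:α` sits between `"none"` (α = 0) and `"avg_logprob"` (α = 1), and `"sqrt"` is the special case α = 0.5, so smaller α values only partly offset the bias toward short paths.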

## Head Aggregation

The `agg_heads` parameter controls how attention heads are combined:

- `"mean"`: Average attention across all heads (default)
- `"max"`: Maximum attention across heads
- `"none"`: Keep all heads separate (returns 3D attention tensor)
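HuggingFace models called with `output_attentions=True` return one attention tensor per layer of shape `(batch, num_heads, seq_len, seq_len)`. The NumPy sketch below shows what the three modes do to a single layer once the batch dimension has been dropped; the function name `aggregate_heads` is hypothetical and this is not the package's code. A practical difference: `"mean"` keeps every row a probability distribution, so log scores stay on the same scale, while `"max"` takes the strongest head per edge and its rows sum to at least 1.

```python
import numpy as np


def aggregate_heads(attn: np.ndarray, agg_heads: str = "mean") -> np.ndarray:
    """Combine one layer's attention of shape [num_heads, seq_len, seq_len]."""
    if agg_heads == "mean":
        return attn.mean(axis=0)  # [seq_len, seq_len]; each row still sums to 1
    if agg_heads == "max":
        return attn.max(axis=0)   # [seq_len, seq_len]; rows sum to at least 1
    if agg_heads == "none":
        return attn               # unchanged 3D tensor, one matrix per head
    raise ValueError(f"unknown agg_heads: {agg_heads!r}")


# Usage: random causal attention for 4 heads over 5 tokens
rng = np.random.default_rng(0)
logits = rng.normal(size=(4, 5, 5))
future = np.triu(np.ones((5, 5), dtype=bool), k=1)  # True above the diagonal
logits[:, future] = -np.inf                          # causal mask
attn = np.exp(logits - logits.max(axis=-1, keepdims=True))
attn /= attn.sum(axis=-1, keepdims=True)             # softmax over key positions

mean_attn = aggregate_heads(attn, "mean")
max_attn = aggregate_heads(attn, "max")
```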

## Example Script

From a source checkout, run the demo script:

```bash
python examples/demo_qwen2.py --text "Your text here" --open-browser
```

Options:

- `--model`: Model name or path (default: Qwen/Qwen2-0.5B)
- `--text`: Text to trace
- `--target-pos`: Target position (default: -1)
- `--layer`: Layer index (default: -1)
- `--top-beam`: Number of beams (default: 5)
- `--top-k`: Top-k predecessors (default: 5)
- `--output`: Output directory (default: output)
- `--open-browser`: Open visualization in browser
- `--device`: Device to use (default: auto)

## Visualization Features

The generated HTML visualization includes:

- **Zoom/Pan**: Scroll to zoom, drag to pan
- **Click to highlight**: Click nodes to highlight connected paths
- **Beam filter**: Dropdown to filter by specific beam
- **Info panel**: Click elements to see details (position, token, attention weights)
- **Color coding**: Beams are color-coded for easy identification
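D3's `d3-sankey` layout takes a graph of `nodes` and `links`, where each link has a `source`, a `target`, and a numeric `value`; by default `source` and `target` can be node indices. The sketch below shows one way traced beams could be flattened into that shape. The function `beams_to_sankey`, the node naming, the extra `beam` field on each link (something a beam filter or colour scale could read), and drawing flow from earlier to later tokens are all assumptions; this is not the package's `SankeyData` format.

```python
import json
from typing import Dict, List


def beams_to_sankey(
    tokens: List[str],
    beam_positions: List[List[int]],     # each path runs from the target position backward
    beam_edge_attns: List[List[float]],  # [b][i]: weight of edge positions[i] -> positions[i + 1]
) -> Dict[str, list]:
    node_index: Dict[int, int] = {}  # sequence position -> node index
    nodes: List[dict] = []
    links: List[dict] = []
    for beam, (positions, weights) in enumerate(zip(beam_positions, beam_edge_attns)):
        for pos in positions:
            if pos not in node_index:  # one node per position, shared by all beams
                node_index[pos] = len(nodes)
                nodes.append({"name": f"{pos}:{tokens[pos]}", "position": pos})
        for i, weight in enumerate(weights):
            later, earlier = positions[i], positions[i + 1]
            # Flow left to right, from the earlier token to the later one
            links.append({
                "source": node_index[earlier],
                "target": node_index[later],
                "value": weight,
                "beam": beam,
            })
    return {"nodes": nodes, "links": links}


# Usage with two toy beams over a 4-token sequence
tokens = ["<s>", "The", "fox", "ran"]
sankey = beams_to_sankey(
    tokens,
    beam_positions=[[3, 2, 1, 0], [3, 1, 0]],
    beam_edge_attns=[[0.6, 0.7, 0.9], [0.2, 0.9]],
)
payload = json.dumps(sankey)  # JSON that an HTML page could pass to d3-sankey
```

Because links always point from a lower position to a higher one, the graph is acyclic, which the Sankey layout requires. Edges shared by several beams appear once per beam here to keep beams separable for colour-coding; summing them into a single link is the alternative if the diagram should show total flow.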

## Development

Install dev dependencies:

```bash
pip install -e ".[dev]"
```

Run tests:

```bash
pytest tests/ -v
```

## License

MIT License
