Metadata-Version: 2.4
Name: khoji
Version: 0.4.0
Summary: Fine-tune embedding models for domain-specific retrieval using LoRA
Keywords: retrieval,embeddings,fine-tuning,lora,information-retrieval,semantic-search,nlp,transformers,deep-learning
Author: Suyash Harlalka
License-Expression: MIT
Classifier: Development Status :: 4 - Beta
Classifier: Intended Audience :: Developers
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: MIT License
Classifier: Operating System :: OS Independent
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Programming Language :: Python :: 3.13
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Classifier: Topic :: Scientific/Engineering :: Information Analysis
Classifier: Typing :: Typed
Requires-Dist: torch>=2.0
Requires-Dist: transformers>=4.36
Requires-Dist: peft>=0.7
Requires-Dist: datasets>=2.16
Requires-Dist: huggingface-hub>=0.20
Requires-Dist: pyyaml>=6.0
Requires-Dist: tqdm>=4.60
Requires-Dist: pillow>=9.0
Requires-Dist: requests>=2.28
Requires-Python: >=3.10
Project-URL: Homepage, https://github.com/suyashh94/khoji
Project-URL: Repository, https://github.com/suyashh94/khoji
Project-URL: Issues, https://github.com/suyashh94/khoji/issues
Project-URL: Documentation, https://github.com/suyashh94/khoji#readme
Description-Content-Type: text/markdown

# khoji

**Make retrieval models actually work on your data**


[Installation](#installation) | [Quick Start](#quick-start) | [Retrieval Modes](#retrieval-modes) | [Training Concepts](#training-concepts) | [Extensibility](#extensibility) | [Architecture](#architecture) | [**Tutorial &rarr;**](https://suyashh94.github.io/khoji/)

> **[Interactive Tutorial](https://suyashh94.github.io/khoji/)** — A comprehensive walkthrough covering all four retrieval modes with real before/after results, complete training code, custom model examples, and embedded retrieval visualizations.


---

Pretrained retrieval models (BERT, CLIP, BLIP-2) are trained on generic web data. They work reasonably well out of the box, but struggle on domain-specific queries — legal documents, medical images, satellite imagery, fashion products, internal knowledge bases. The standard fix is fine-tuning, but wiring together the data pipeline, negative mining, LoRA, training loop, and evaluation for retrieval is a lot of boilerplate.

khoji handles all of that. Point it at your data and a base model, and it fine-tunes a retrieval adapter with LoRA, including hard negative mining and standard IR evaluation. It supports text search, image search, and composed retrieval, where queries and targets can be any mix of image and text (e.g., "find this dress but in red", or product search that matches images with descriptions). Text and image search can each be run from a single YAML config for quick experiments, and all three modes expose composable Python components when you need full control.

### Three retrieval modes

| Mode | Query | Target | Models | Use case |
|------|-------|--------|--------|----------|
| **Text → Text** | text | text | BERT, BGE, sentence-transformers | Document search, FAQ matching, semantic search |
| **Text → Image** | text | image | CLIP, SigLIP | Image search from text descriptions |
| **Composed (mixed-mode)** | image, text, or both | image, text, or both | BLIP-2 | "Find me this dress but in red", product search with metadata |

### Two levels of abstraction

| Level | What you write | Best for |
|-------|---------------|----------|
| **Config-driven** | A YAML file → `run()` / `run_multimodal()` | Reproducible experiments, quick iteration |
| **Python API** | Compose individual components (model, trainer, evaluator, data) | Custom workflows, non-standard data sources, research |

### What you can plug in

| Component | Built-in | Custom |
|-----------|----------|--------|
| **Models** | Any HuggingFace model (auto-detected) | Any `nn.Module` + encode function |
| **Datasets** | BEIR (20+), Flickr30k, RSICD, FashionIQ | JSONL/TSV files, or raw Python dicts |
| **Loss functions** | Triplet Margin, InfoNCE, Contrastive | Any `(query, pos, neg) -> scalar` callable |
| **Metrics** | nDCG@k, MRR@k, Recall@k | Any `(ranked_ids, qrel, k) -> float` callable |
| **Negative mining** | Random, hard, mixed | Build your own `Triplet` / `ComposedTriplet` objects |

### Other highlights

- Parameter-efficient fine-tuning via **LoRA** (or full fine-tuning with `lora: null`)
- Auto-detection of model pooling strategy and LoRA target modules
- Iterative hard negative mining (re-mine with the fine-tuned model)
- Mixed precision training (fp16/bf16)
- All metrics implemented from scratch — no external IR evaluation frameworks
- Hardware support: CUDA, Apple Silicon (MPS), CPU

---

## Installation

**Requirements:** Python >= 3.10

```bash
# From PyPI
pip install khoji

# From source
git clone https://github.com/suyashh94/khoji.git
cd khoji
uv sync              # or: pip install -e .
uv sync --group dev  # adds pytest, ruff
```

---

## Quick Start

### Text → Text

```bash
khoji configs/minilm_scifact_full.yaml        # train + evaluate
```

```python
from khoji import ForgeConfig, run

config = ForgeConfig.from_yaml("configs/minilm_scifact_full.yaml")
result = run(config)
print(result.finetuned.metrics)   # {"ndcg@1": 0.45, "mrr@5": 0.52, ...}
print(result.adapter_dir)         # path to saved LoRA adapter
```

### Text → Image

```bash
khoji multimodal configs/clip_rsicd_full.yaml
```

```python
from khoji import MultimodalForgeConfig, run_multimodal

config = MultimodalForgeConfig.from_yaml("configs/clip_rsicd_full.yaml")
result = run_multimodal(config)
```

### Composed Retrieval (mixed-mode)

Composed retrieval has more moving parts (image loading, multiple modalities, dataset conversion) than the other modes. Start with the example scripts rather than a YAML config:

```bash
python scripts/fashioniq/download_data.py
python scripts/train_composed_retrieval_api.py --category dress
```

See [Section 3](#3-composed-retrieval-mixed-mode) for the full Python API.

---

## Retrieval Modes

### 1. Text → Text

Fine-tune text embedding models (BERT, BGE, sentence-transformers) for domain-specific document retrieval.

#### Config-driven

```yaml
model:
  name: BAAI/bge-base-en-v1.5
  # adapter_path: null          # warm-start from existing adapter
  # dtype: null                 # "fp16", "bf16", or null (fp32)

data:
  dataset: fiqa                  # BEIR dataset name or path to local directory
  split: train
  negatives: mixed               # "random", "hard", or "mixed"
  n_random: 2                    # random negatives per pair (mixed mode)
  n_hard: 1                      # hard negatives per pair (mixed mode)
  # n_negatives: 1              # negatives per pair (random/hard modes)
  # n_queries: null             # subset of queries (null = all)
  # corpus_size: null           # corpus limit for mining (null = all)
  # top_k: 50                   # top-k for hard negative mining
  # skip_top: 0                 # skip top N non-relevant (avoids false negatives)
  # mining_rounds: 1            # iterative mining rounds (re-mine with fine-tuned model)

lora:
  r: 8
  alpha: 16
  dropout: 0.1
  # target_modules: null        # auto-detected per architecture
# lora: null                    # uncomment for full fine-tuning

train:
  epochs: 3
  batch_size: 8
  grad_accum_steps: 4            # effective batch = batch_size * grad_accum_steps
  lr: 2e-5
  weight_decay: 0.01
  warmup_steps: 100              # linear warmup then linear decay
  max_grad_norm: 1.0
  max_length: 512
  loss: infonce                  # "triplet", "infonce", or "contrastive"
  margin: 0.2                    # for triplet loss
  temperature: 0.05              # for infonce loss
  # mixed_precision: null        # "fp16", "bf16", or null
  # overfit_batches: null        # set to 1 for debugging
  sanity_check_samples: 10

eval:
  # dataset: null               # null = use data.dataset
  k_values: [1, 5, 10]
  split: test
  run_before: true
  run_after: true

seed: 42
output_dir: ./forge-output
```
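The `grad_accum_steps` and `warmup_steps` comments above can be made concrete with a little arithmetic. The sketch below assumes the schedule ramps linearly from 0 to `lr` over `warmup_steps` optimizer steps and then decays linearly to 0 at the final step (the shape of Hugging Face's `get_linear_schedule_with_warmup`). khoji's exact step accounting may differ, and `n_triplets` is a hypothetical training-set size.

```python
import math


def linear_warmup_decay(step: int, warmup_steps: int, total_steps: int) -> float:
    """LR multiplier: linear ramp 0 -> 1 during warmup, then linear decay 1 -> 0."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


# Values from the config above; n_triplets is a hypothetical dataset size.
batch_size, grad_accum_steps, epochs, peak_lr = 8, 4, 3, 2e-5
warmup_steps = 100
n_triplets = 10_000

effective_batch = batch_size * grad_accum_steps            # 32 triplets per optimizer step
steps_per_epoch = math.ceil(n_triplets / effective_batch)  # 313
total_steps = steps_per_epoch * epochs                     # 939

for step in (0, 50, warmup_steps, total_steps // 2, total_steps):
    print(step, peak_lr * linear_warmup_decay(step, warmup_steps, total_steps))
```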

#### Python API (component-by-component)

```python
from khoji import (
    EmbeddingModel, Evaluator, Trainer, TrainingConfig,
    TripletDataset, LoRASettings,
    load_beir, build_mixed_negatives,
)

# 1. Load data
dataset = load_beir("fiqa", split="train")

# 2. Build training triplets
model = EmbeddingModel("BAAI/bge-base-en-v1.5")
triplets = build_mixed_negatives(dataset, model, n_random=2, n_hard=1, top_k=50)

# 3. Train
config = TrainingConfig(
    epochs=3, batch_size=8, lr=2e-5,
    lora=LoRASettings(r=8, alpha=16),
    save_dir="./my-adapter",
)
trainer = Trainer("BAAI/bge-base-en-v1.5", config)
history = trainer.train(TripletDataset(triplets))

# 4. Evaluate
evaluator = Evaluator("BAAI/bge-base-en-v1.5", adapter_path="./my-adapter")
result = evaluator.evaluate("fiqa", split="test", k_values=[1, 5, 10])
result.print()

# 5. Inference
model = EmbeddingModel("BAAI/bge-base-en-v1.5", adapter_path="./my-adapter")
embeddings = model.encode(["What is compound interest?", "How do bonds work?"])
```

#### Custom datasets

Every dataset in khoji is just three things: **queries**, a **corpus**, and **relevance judgments** (qrels) mapping which corpus items are relevant to which queries. You can provide these as local files or as Python dicts.

**Option A: Local files** — create a directory with three files:

```
my_dataset/
  queries.jsonl   # {"_id": "q1", "text": "What is compound interest?"}
  corpus.jsonl    # {"_id": "d1", "text": "Compound interest is ...", "title": "Optional Title"}
  qrels.tsv       # q1\td1\t1  (tab-separated: query_id, doc_id, relevance_score. No header.)
```

```yaml
data:
  dataset: ./my_dataset    # point to the directory
```
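To produce the Option A layout from your own records, a few lines of standard-library Python are enough. This is a minimal sketch: `write_retrieval_dataset` is a hypothetical helper (not part of khoji) that follows the field names and headerless TSV format described above.

```python
import csv
import json
from pathlib import Path


def write_retrieval_dataset(out_dir, queries, corpus, qrels):
    """Write queries.jsonl, corpus.jsonl and a headerless qrels.tsv.

    queries: {query_id: text}
    corpus:  {doc_id: text} or {doc_id: (title, text)}
    qrels:   {query_id: {doc_id: relevance_score}}
    """
    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)

    with open(out / "queries.jsonl", "w", encoding="utf-8") as f:
        for qid, text in queries.items():
            f.write(json.dumps({"_id": qid, "text": text}) + "\n")

    with open(out / "corpus.jsonl", "w", encoding="utf-8") as f:
        for did, doc in corpus.items():
            title, text = doc if isinstance(doc, tuple) else ("", doc)
            record = {"_id": did, "text": text}
            if title:  # "title" is optional
                record["title"] = title
            f.write(json.dumps(record) + "\n")

    with open(out / "qrels.tsv", "w", encoding="utf-8", newline="") as f:
        writer = csv.writer(f, delimiter="\t", lineterminator="\n")
        for qid, docs in qrels.items():
            for did, score in docs.items():
                writer.writerow([qid, did, score])  # query_id, doc_id, relevance_score


write_retrieval_dataset(
    "my_dataset",
    queries={"q1": "What is compound interest?"},
    corpus={"d1": ("Interest basics", "Compound interest is ..."), "d2": "Unrelated doc."},
    qrels={"q1": {"d1": 1}},
)
```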

**Option B: Python dicts** — build from any source (database, CSV, API, etc.):

```python
from khoji import RetrievalDataset

dataset = RetrievalDataset(
    queries={"q1": "What is compound interest?"},
    corpus={"d1": "Compound interest is interest on interest.", "d2": "Unrelated doc."},
    qrels={"q1": {"d1": 1}},   # query q1 → doc d1 is relevant (score 1)
)
# Pass this to build_random_negatives(), Evaluator.evaluate(dataset=...), etc.
```

Training and evaluation datasets are independent — you can train on one and evaluate on another:

```yaml
data:
  dataset: ./my_train_data
eval:
  dataset: ./my_eval_data     # null = same as data.dataset
```

#### Supported models

Any HuggingFace model compatible with `AutoModel` / `AutoTokenizer`. Pooling is auto-detected from the model's sentence-transformers config.

| Model | Pooling | Architecture |
|-------|---------|-------------|
| `BAAI/bge-base-en-v1.5` | CLS | BERT |
| `sentence-transformers/all-MiniLM-L6-v2` | Mean | BERT |
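
For reference, CLS pooling takes the hidden state of the first token, while mean pooling averages the hidden states of all non-padding tokens. A minimal pure-Python sketch of those two rules on one sequence (khoji's own pooling operates on batched tensors and may differ in details):

```python
def cls_pool(hidden_states):
    """hidden_states: per-token vectors for one sequence; return the first ([CLS]) token's vector."""
    return list(hidden_states[0])


def mean_pool(hidden_states, attention_mask):
    """Average the token vectors whose attention_mask entry is 1, ignoring padding."""
    dim = len(hidden_states[0])
    total = [0.0] * dim
    count = 0
    for vector, keep in zip(hidden_states, attention_mask):
        if keep:
            total = [t + v for t, v in zip(total, vector)]
            count += 1
    return [t / max(count, 1) for t in total]


# One sequence: [CLS], two real tokens, one padding token (hidden_dim = 2).
tokens = [[1.0, 0.0], [2.0, 4.0], [4.0, 2.0], [9.0, 9.0]]
mask = [1, 1, 1, 0]
print(cls_pool(tokens))         # [1.0, 0.0]
print(mean_pool(tokens, mask))  # [2.3333333333333335, 2.0]
```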

Auto-detected LoRA targets per architecture:

| Architecture | Target Modules |
|-------------|---------------|
| BERT, RoBERTa, XLM-RoBERTa | `query`, `key`, `value` |
| DistilBERT | `q_lin`, `k_lin`, `v_lin` |
| DeBERTa (v1/v2) | `query_proj`, `key_proj`, `value_proj` |
| Mistral, LLaMA | `q_proj`, `k_proj`, `v_proj` |

---

### 2. Text → Image

Fine-tune cross-modal models (CLIP, SigLIP) where queries are text and documents are images.

#### Config-driven

```yaml
model:
  name: openai/clip-vit-base-patch32
  # adapter_path: null
  # dtype: null
  lora_target: both              # "vision", "text", or "both"

data:
  dataset: nlphuji/flickr30k     # or "arampacha/rsicd" or local path
  split: train
  negatives: random
  n_negatives: 1
  cache_dir: null                 # cache downloaded images locally
  # ... all other data params same as text-to-text

lora:
  r: 8
  alpha: 16
  dropout: 0.1

train:
  epochs: 3
  batch_size: 8
  lr: 2e-5
  max_length: 77                  # CLIP default
  loss: infonce
  temperature: 0.05
  # ... all other train params same as text-to-text

# Optional: override image preprocessing
# preprocess:
#   image_size: 224
#   mean: [0.48145466, 0.4578275, 0.40821073]
#   std: [0.26862954, 0.26130258, 0.27577711]

eval:
  k_values: [1, 5, 10]
  run_before: true
  run_after: true

output_dir: ./forge-output/flickr30k
```

```bash
khoji multimodal configs/clip_rsicd_full.yaml
```

#### Python API

```python
from khoji import (
    MultimodalEmbeddingModel, MultimodalEvaluator,
    MultimodalTrainer, MultimodalTrainingConfig,
    MultimodalTripletDataset, LoRASettings,
    load_custom_multimodal, build_random_negatives_multimodal,
)

# 1. Load data
dataset = load_custom_multimodal("./my_image_dataset")

# 2. Build triplets
triplets = build_random_negatives_multimodal(dataset, n_negatives=1)

# 3. Train
config = MultimodalTrainingConfig(
    epochs=3, batch_size=8, lr=2e-5,
    lora=LoRASettings(r=8, alpha=16),
    lora_target="both",
    save_dir="./my-clip-adapter",
    base_dir="./my_image_dataset",
)
trainer = MultimodalTrainer("openai/clip-vit-base-patch32", config)
history = trainer.train(MultimodalTripletDataset(triplets))

# 4. Evaluate
evaluator = MultimodalEvaluator("openai/clip-vit-base-patch32", adapter_path="./my-clip-adapter")
result = evaluator.evaluate(dataset=dataset, k_values=[1, 5, 10])
result.print()

# 5. Inference
model = MultimodalEmbeddingModel("openai/clip-vit-base-patch32", adapter_path="./my-clip-adapter")
text_emb = model.encode_text(["a photo of a sunset"])
img_emb = model.encode_image_sources(["sunset.jpg", "cat.jpg"], base_dir="./photos")

import torch
scores = torch.mm(text_emb, img_emb.t()).squeeze(0)
```

#### LoRA targeting

Control which encoder(s) to fine-tune:

| `lora_target` | What's trained | When to use |
|---------------|---------------|-------------|
| `both` | Text + vision encoders | Default. General domain adaptation. |
| `vision` | Vision encoder only | Text understanding is fine, images are domain-specific (satellite, medical). |
| `text` | Text encoder only | Images are generic, queries are domain-specific. |

#### Custom image datasets

Same three-file structure as text-to-text, but `corpus.jsonl` uses an `image` field instead of `text`. Image paths are relative to the dataset directory; HTTP(S) URLs also work.

```
my_image_dataset/
  queries.jsonl   # {"_id": "q1", "text": "a dog playing fetch"}
  corpus.jsonl    # {"_id": "d1", "image": "images/dog.jpg"}   (relative path or URL)
  qrels.tsv       # q1\td1\t1
  images/         # local image files
```
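
Each image source in `corpus.jsonl` is either an HTTP(S) URL or a path relative to the dataset directory. The sketch below illustrates that rule; `resolve_image_source` is a hypothetical helper, not khoji's internal resolver, and passing absolute paths through unchanged is an assumption.

```python
import os
from urllib.parse import urlparse


def resolve_image_source(source: str, base_dir: str) -> str:
    """Return URLs unchanged and join relative paths onto base_dir."""
    if urlparse(source).scheme in ("http", "https"):
        return source
    if os.path.isabs(source):  # assumption: absolute paths are used as-is
        return source
    return os.path.join(base_dir, source)


print(resolve_image_source("images/dog.jpg", "./my_image_dataset"))
print(resolve_image_source("https://example.com/cat.jpg", "./my_image_dataset"))
```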

Or build in Python:

```python
from khoji import MultimodalRetrievalDataset

dataset = MultimodalRetrievalDataset(
    queries={"q1": "a dog playing fetch"},
    corpus={"d1": "images/dog.jpg", "d2": "images/cat.jpg"},
    qrels={"q1": {"d1": 1}},
    base_dir="./my_image_dataset",   # resolve relative paths from here
)
```

#### Built-in datasets

| Dataset | Config name | Description |
|---------|------------|-------------|
| Flickr30k | `nlphuji/flickr30k` | ~30k images, 5 captions each. General purpose. |
| RSICD | `arampacha/rsicd` | ~10k satellite/aerial images. A domain CLIP was not specifically trained on. |

#### Supported models

| Model | Type | Embedding Dim |
|-------|------|--------------|
| `openai/clip-vit-base-patch32` | CLIP | 512 |
| `openai/clip-vit-large-patch14` | CLIP | 768 |
| `google/siglip-base-patch16-224` | SigLIP | 768 |

Any CLIP or SigLIP variant on HuggingFace should work.

---

### 3. Composed Retrieval (mixed-mode)

Fine-tune joint encoder models (BLIP-2) for mixed-mode retrieval: queries and targets can each be image-only, text-only, or image+text. This covers composed image retrieval ("find this dress but in red"), product search with metadata, and any scenario where items combine visual and textual information.

Composed retrieval has more moving parts than the other modes — image loading, multiple modalities per item, and dataset-specific conversion logic. **Start with the example scripts** rather than a YAML config:

```bash
python scripts/fashioniq/download_data.py
python scripts/train_composed_retrieval_api.py --category dress
```

The scripts handle FashionIQ data download, dataset conversion, training, and evaluation end-to-end. Once you understand the pipeline, use the Python API below to build your own workflows.

> **Note:** Unlike text and multimodal modes, composed retrieval does not have a `khoji composed <config.yaml>` CLI command. Use the Python API or the example scripts directly.

#### Python API

```python
from khoji import (
    JointEmbeddingModel, ComposedEvaluator,
    ComposedTrainer, ComposedTrainingConfig,
    ComposedTripletDataset, LoRASettings,
    load_custom_composed, build_random_negatives_composed,
)

# 1. Load data
dataset = load_custom_composed("./my_composed_dataset")

# 2. Build triplets
triplets = build_random_negatives_composed(dataset, n_negatives=3)

# 3. Train
config = ComposedTrainingConfig(
    epochs=5, batch_size=8, lr=2e-5,
    lora=LoRASettings(r=8, alpha=16),
    save_dir="./my-composed-adapter",
    cache_dir="./image_cache",
)
trainer = ComposedTrainer("Salesforce/blip2-itm-vit-g", config)
history = trainer.train(ComposedTripletDataset(triplets))

# 4. Evaluate
evaluator = ComposedEvaluator(
    "Salesforce/blip2-itm-vit-g", adapter_path="./my-composed-adapter"
)
result = evaluator.evaluate(dataset=dataset, k_values=[1, 5, 10, 50])
result.print()

# 5. Inference: queries and targets can be any combination of modalities
import torch
from khoji import load_image

model = JointEmbeddingModel(
    "Salesforce/blip2-itm-vit-g", adapter_path="./my-composed-adapter"
)
ref_img = load_image("reference.jpg")
img1, img2 = load_image("red_dress.jpg"), load_image("blue_shirt.jpg")  # example gallery images
query_emb = model.encode(images=[ref_img], texts=["make it red"])                   # image + text query
gallery_emb = model.encode(images=[img1, img2], texts=["red dress", "blue shirt"])  # image + text targets

scores = torch.mm(query_emb, gallery_emb.t()).squeeze(0)  # similarity of the query to each target
best_match = scores.argmax().item()                        # index of the best-scoring target
```

#### Custom composed datasets

Composed datasets differ from the other two modes: both queries and corpus items are `(image, text)` tuples. Either field can be `""` to indicate that modality is absent — enabling image-only, text-only, or image+text items.

**Local files:**

```
my_composed_dataset/
  queries.jsonl   # {"_id": "q1", "image": "imgs/ref.jpg", "text": "make it red"}
  corpus.jsonl    # {"_id": "d1", "image": "imgs/target.jpg", "text": "red dress, size M"}
  qrels.tsv       # q1\td1\t1
```

Both `"image"` and `"text"` fields are optional in queries and corpus — at least one must be present per item. This means you can mix modalities freely: image-only corpus items, text-only queries, or full image+text on both sides.
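
Because either field may be missing or empty, it can help to normalize records into `(image, text)` tuples before building a `ComposedRetrievalDataset`. A minimal sketch, assuming the JSONL fields shown above; `parse_composed_record` and `modality` are hypothetical helpers, not khoji functions:

```python
import json


def parse_composed_record(line: str) -> tuple[str, tuple[str, str]]:
    """Turn one JSONL line into (id, (image, text)), using "" for an absent modality."""
    record = json.loads(line)
    image = record.get("image") or ""
    text = record.get("text") or ""
    if not image and not text:
        raise ValueError(f"item {record.get('_id')!r} has neither 'image' nor 'text'")
    return record["_id"], (image, text)


def modality(item: tuple[str, str]) -> str:
    """Classify an (image, text) tuple as image-only, text-only, or image+text."""
    image, text = item
    if image and text:
        return "image+text"
    return "image" if image else "text"


lines = [
    '{"_id": "q1", "image": "imgs/ref.jpg", "text": "make it red"}',
    '{"_id": "q2", "text": "red cocktail dress"}',
    '{"_id": "d2", "image": "imgs/short_sleeve_shirt.jpg"}',
]
items = dict(parse_composed_record(line) for line in lines)
print({item_id: modality(item) for item_id, item in items.items()})
# {'q1': 'image+text', 'q2': 'text', 'd2': 'image'}
```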

**Python dicts:**

```python
from khoji import ComposedRetrievalDataset

dataset = ComposedRetrievalDataset(
    queries={
        "q1": ("imgs/ref_dress.jpg", "make it red"),       # image + text query
        "q2": ("", "red cocktail dress"),                   # text-only query
    },
    corpus={
        "d1": ("imgs/red_dress.jpg", "red dress, size M"),  # image + text target
        "d2": ("imgs/short_sleeve_shirt.jpg", ""),           # image-only target
        "d3": ("imgs/other.jpg", "other item"),
    },
    qrels={"q1": {"d1": 1}, "q2": {"d1": 1}},
    base_dir="./my_dataset",    # resolve relative image paths from here
)
```

#### Supported models

| Model | Type | Description |
|-------|------|-------------|
| `Salesforce/blip2-itm-vit-g` | BLIP-2 | Joint image-text encoder with Q-Former. 256-dim shared space. |

Any BLIP-2 variant on HuggingFace should work. Custom joint encoders are also supported (see [Extensibility](#extensibility)).

---

## Training Concepts

These concepts apply across all three retrieval modes.

### Loss functions

| Loss | Config value | Formula | Key param | When to use |
|------|-------------|---------|-----------|-------------|
| **Triplet Margin** | `triplet` | `relu(d(q,p) - d(q,n) + margin)` | `margin: 0.2` | Good default. Works with small batches and random negatives. |
| **InfoNCE** | `infonce` | Cross-entropy with in-batch negatives | `temperature: 0.05` | Best with larger batches and hard negatives. Typically strongest. |
| **Contrastive** | `contrastive` | `-cos(q,p) + cos(q,n)` | (none) | Simple baseline. No hyperparameters beyond LR. |

Custom loss functions are supported via the Python API — any `(query_emb, pos_emb, neg_emb) -> scalar` callable works.
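
To make the table concrete, here are the three losses in plain Python. This is a sketch of the formulas, not khoji's implementation. It assumes cosine similarity, takes `d(q, p) = 1 - cos(q, p)` as the triplet distance, and for InfoNCE lets each query choose among every in-batch positive plus its own negative. khoji's exact choices (distance function, which negatives enter the softmax) may differ.

```python
import math


def cos(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def triplet_loss(q, p, n, margin=0.2):
    # relu(d(q,p) - d(q,n) + margin) with cosine distance d = 1 - cos
    return max(0.0, (1 - cos(q, p)) - (1 - cos(q, n)) + margin)


def contrastive_loss(q, p, n):
    # -cos(q,p) + cos(q,n): pull the positive closer, push the negative away
    return -cos(q, p) + cos(q, n)


def infonce_loss(queries, positives, negatives, temperature=0.05):
    """Mean cross-entropy where query i must pick positives[i] among all
    in-batch positives plus its own negative."""
    total = 0.0
    for i, q in enumerate(queries):
        logits = [cos(q, pos) / temperature for pos in positives]
        logits.append(cos(q, negatives[i]) / temperature)
        top = max(logits)
        log_norm = top + math.log(sum(math.exp(z - top) for z in logits))
        total += log_norm - logits[i]  # -log softmax at the correct positive
    return total / len(queries)


q, p, n = [1.0, 0.0], [0.8, 0.6], [0.0, 1.0]
print(triplet_loss(q, p, n))                # 0.0: the positive already wins by more than the margin
print(round(contrastive_loss(q, p, n), 6))  # -0.8
```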

### Negative mining strategies

Retrieval fine-tuning requires triplets: (query, relevant item, non-relevant item). The non-relevant item is the "negative." How you choose negatives has a big impact on what the model learns.

#### Random negatives (`negatives: random`)

Randomly sample non-relevant items from the corpus. Fast (no model encoding needed), and sufficient for initial training where the model needs to learn basic relevance signals.

```yaml
data:
  negatives: random
  n_negatives: 3       # 3 random negatives per (query, positive) pair
```

#### Hard negatives (`negatives: hard`)

Encode the entire corpus and all queries with the current model, then for each query pick the **most similar non-relevant items** as negatives. These are items the model currently thinks are relevant but aren't — forcing the model to learn finer distinctions.

How it works:
1. Encode all corpus items and queries into embeddings
2. For each query, rank corpus items by cosine similarity
3. From the top-`top_k` results, filter out actually-relevant items
4. Optionally skip the top N (`skip_top`) — see below
5. Pick `n_negatives` from the remaining as hard negatives

```yaml
data:
  negatives: hard
  n_negatives: 3       # 3 hard negatives per (query, positive) pair
  top_k: 50            # consider top-50 most similar corpus items
  skip_top: 0          # how many to skip (see below)
```
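
The five steps above can be sketched in plain Python for a single query, given a precomputed similarity score for each corpus item (steps 1 and 2). This assumes step 5 samples uniformly from the remaining window; khoji may instead take them in rank order. `mine_hard_negatives` is a hypothetical name.

```python
import random


def mine_hard_negatives(scores, relevant, n_negatives=3, top_k=50, skip_top=0, seed=42):
    """Pick hard negatives for one query.

    scores:   {doc_id: similarity between the query and the doc}
    relevant: set of doc_ids labeled relevant for this query (from qrels)
    """
    # Step 2: rank corpus items by similarity, highest first
    ranked = sorted(scores, key=scores.get, reverse=True)
    # Step 3: keep the top_k window and drop labeled-relevant items
    candidates = [d for d in ranked[:top_k] if d not in relevant]
    # Step 4: skip the most similar non-relevant items (likely unlabeled positives)
    candidates = candidates[skip_top:]
    # Step 5: sample n_negatives from what remains (assumption: uniform sampling)
    rng = random.Random(seed)
    return rng.sample(candidates, min(n_negatives, len(candidates)))


scores = {"d1": 0.91, "d2": 0.88, "d3": 0.85, "d4": 0.60, "d5": 0.40, "d6": 0.10}
print(mine_hard_negatives(scores, relevant={"d1"}, n_negatives=2, top_k=5, skip_top=1))
```

With `top_k=5` and `skip_top=1`, `d1` is removed as relevant, `d2` is skipped as a possible unlabeled positive, and the negatives come from `d3`, `d4` and `d5`.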

#### Mixed negatives (`negatives: mixed`)

Combines random and hard negatives in the same training set. Random negatives teach basic "this is clearly irrelevant" discrimination. Hard negatives push fine-grained ranking — "these two items look similar but only one is correct." This usually gives the best training signal.

```yaml
data:
  negatives: mixed
  n_random: 2          # 2 random negatives per pair
  n_hard: 1            # 1 hard negative per pair
  top_k: 50
```

Note: `n_negatives` is used by `random` and `hard` modes. `n_random` and `n_hard` are used by `mixed` mode. They are separate parameters because mixed mode needs counts for each type.

#### `top_k` — mining search window

When mining hard negatives, `top_k` controls how many top-ranked corpus items to consider. A larger `top_k` searches deeper but takes longer. Typical value: 50.

If `top_k` is too small, you may not find enough non-relevant items (especially for queries where many top results are relevant). If it's too large, the "hard" negatives become easy (they're far down the ranking).

#### `skip_top` — avoiding false negatives

Most retrieval datasets have **incomplete relevance judgments** (qrels). A document may be perfectly relevant to a query yet remain unlabeled, simply because no annotator saw it. These unlabeled positives tend to cluster at the very top of the model's ranking: they look relevant because they *are* relevant.

If you mine these as "hard negatives," you're training the model to push away items that are actually good matches. This hurts performance.

`skip_top` skips the top N non-relevant results before picking hard negatives:

```yaml
data:
  skip_top: 5          # skip the 5 most similar non-relevant items
  top_k: 50            # then pick from the remaining non-relevant items in the top 50
```

**When to use it:**
- Datasets with sparse qrels (few labeled positives per query): `skip_top: 5-10`
- Datasets with comprehensive qrels: `skip_top: 0` is fine
- When in doubt, `skip_top: 5` is a safe default for hard/mixed negatives

#### `mining_rounds` — iterative re-mining

A single round of hard negative mining uses the **pretrained model** to find hard negatives. But after training, the model has improved — what was "hard" before may now be easy. Iterative mining repeats the mine-train cycle:

```
Round 1: pretrained model → mine negatives → train → adapter_r1
Round 2: fine-tuned model (adapter_r1) → re-mine harder negatives → train → adapter_r2 (final)
```

Each round halves the learning rate to avoid overshooting as negatives get harder.

```yaml
data:
  negatives: mixed      # only meaningful for hard/mixed (random doesn't use mining)
  mining_rounds: 2      # 2 rounds of mine → train
```
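
The mine-then-train cycle in the diagram, with the learning rate halved each round, reduces to a short loop. A sketch, assuming the configured `lr` applies to round 1; `mine` and `train` are hypothetical stand-ins for the mining and training steps, not khoji functions.

```python
def round_learning_rates(base_lr: float, mining_rounds: int) -> list[float]:
    """Round r (1-indexed) trains with base_lr / 2**(r - 1)."""
    return [base_lr / 2 ** (r - 1) for r in range(1, mining_rounds + 1)]


def iterative_mining(base_lr, mining_rounds, mine, train):
    adapter = None  # round 1 mines with the pretrained model (no adapter yet)
    for r, lr in enumerate(round_learning_rates(base_lr, mining_rounds), start=1):
        triplets = mine(adapter)  # re-mine with the latest adapter
        adapter = train(triplets, lr=lr, name=f"adapter_r{r}")
    return adapter


print(round_learning_rates(2e-5, 3))  # [2e-05, 1e-05, 5e-06]
```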

**When to use it:**
- 1 round is sufficient for most tasks
- 2 rounds help when the pretrained model is already reasonable on your domain and you need to push further
- 3+ rounds give diminishing returns and risk overfitting to hard negatives

#### Choosing a strategy

| Situation | Recommended |
|-----------|------------|
| First experiment / quick iteration | `random` with `n_negatives: 1-3` |
| Production training | `mixed` with `n_random: 2, n_hard: 1` |
| Squeezing last bits of performance | `mixed` with `mining_rounds: 2, skip_top: 5` |
| Very large corpus (>1M items) | `random` first, then `hard` on a `corpus_size` subset |

### LoRA vs full fine-tuning

**LoRA (default)**: Only the adapter weights are trained and saved (a few MB). Base model weights stay frozen.

```yaml
lora:
  r: 8        # rank (4, 8, 16, 32 — higher = more capacity)
  alpha: 16   # scaling factor (convention: 2 * r)
  dropout: 0.1
```
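
For a rough sense of adapter size: LoRA adds two low-rank matrices, `A` (r × d_in) and `B` (d_out × r), to each targeted Linear layer and scales their product by `alpha / r`. The arithmetic below assumes a BERT-base-sized encoder (12 layers, hidden size 768) with `query`, `key` and `value` targeted, and ignores biases and file-format overhead.

```python
def lora_param_count(r, d_in, d_out, n_matrices):
    # Each adapted Linear layer gets A (r x d_in) and B (d_out x r)
    return n_matrices * r * (d_in + d_out)


layers, hidden, targets_per_layer = 12, 768, 3  # BERT-base; query, key, value
r, alpha = 8, 16

params = lora_param_count(r, hidden, hidden, layers * targets_per_layer)
print(params)                            # 442368
print(params * 4 / 1e6, "MB in fp32")    # about 1.77 MB
print("scaling alpha / r =", alpha / r)  # 2.0
```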

**Full fine-tuning**: All parameters are trained and saved (hundreds of MB). Use a lower learning rate.

```yaml
lora: null
train:
  lr: 1e-5   # lower LR to avoid catastrophic forgetting
```

### Model precision

Two independent controls:

| Setting | What it does | Values |
|---------|-------------|--------|
| `model.dtype` | Precision of base model weights in memory | `null` (fp32), `"fp16"`, `"bf16"` |
| `train.mixed_precision` | AMP during forward/backward pass | `null` (fp32), `"fp16"`, `"bf16"` |

Use both together for maximum memory savings:

```yaml
model:
  dtype: bf16
train:
  mixed_precision: bf16
```

### Evaluation metrics

All implemented from scratch (no external IR evaluation libraries).

| Metric | Description |
|--------|-------------|
| **nDCG@k** | Normalized Discounted Cumulative Gain. Measures ranking quality with graded relevance. |
| **MRR@k** | Mean Reciprocal Rank. 1 / rank of the first relevant result within the top k (0 if there is none), averaged over queries. |
| **Recall@k** | Fraction of all relevant documents found in top-k. |
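
The per-query versions of these metrics take only a few lines each; the reported scores are their averages over all queries. The sketch below uses linear gains (`rel / log2(rank + 1)`) for nDCG, which is one common variant. khoji's built-in `ndcg_at_k`, `mrr_at_k` and `recall_at_k` may differ in such details (for example, exponential gains).

```python
import math


def dcg(gains):
    # ranks are 1-based, so 0-based position i is discounted by log2(i + 2)
    return sum(g / math.log2(i + 2) for i, g in enumerate(gains))


def ndcg(ranked, qrel, k):
    gains = [qrel.get(d, 0) for d in ranked[:k]]
    ideal = sorted(qrel.values(), reverse=True)[:k]
    best = dcg(ideal)
    return dcg(gains) / best if best > 0 else 0.0


def reciprocal_rank(ranked, qrel, k):
    for rank, d in enumerate(ranked[:k], start=1):
        if qrel.get(d, 0) > 0:
            return 1.0 / rank
    return 0.0  # no relevant document in the top k


def recall(ranked, qrel, k):
    relevant = {d for d, s in qrel.items() if s > 0}
    if not relevant:
        return 0.0
    return len(relevant & set(ranked[:k])) / len(relevant)


ranked = ["d3", "d1", "d5", "d2"]
qrel = {"d1": 2, "d5": 1}
print(round(ndcg(ranked, qrel, 3), 4), reciprocal_rank(ranked, qrel, 3), recall(ranked, qrel, 3))
# 0.6697 0.5 1.0
```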

### Output structure

```
output_dir/
  config.yaml              # saved config for reproducibility
  train_history.json       # per-step loss, LR, grad norms, per-epoch loss
  adapter/                 # final LoRA adapter weights
    adapter_model.safetensors
    adapter_config.json
  adapter_r1/              # round 1 adapter (only when mining_rounds > 1)
  baseline.json            # baseline eval metrics (if run_before: true)
  finetuned.json           # fine-tuned eval metrics (if run_after: true)
```

### Result objects

**`RunResult`** (returned by `run()`, `run_multimodal()`, `run_composed()`):

| Field | Type | Description |
|-------|------|-------------|
| `history` | `TrainHistory` | `step_loss`, `step_lr`, `step_grad_norm`, `epoch_loss` |
| `baseline` | `EvalResult \| None` | Baseline metrics (None if `run_before: false`) |
| `finetuned` | `EvalResult \| None` | Fine-tuned metrics (None if `run_after: false`) |
| `adapter_dir` | `str \| None` | Path to saved LoRA adapter |

**`EvalResult`**:

| Field | Type |
|-------|------|
| `metrics` | `dict[str, float]` — e.g. `{"ndcg@5": 0.42, "mrr@5": 0.51}` |
| `model_name` | `str` |
| `dataset_name` | `str` |
| `num_queries` | `int` |
| `num_corpus` | `int` |

Methods: `print()`, `save(path)`, `to_dict()`.
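
A common follow-up is to compare the two `EvalResult.metrics` dicts inside a `RunResult`. A small sketch: `metric_deltas` is a hypothetical helper, the input numbers are made up for illustration, and the commented usage assumes both `run_before` and `run_after` were enabled so neither field is `None`.

```python
def metric_deltas(baseline: dict[str, float], finetuned: dict[str, float]) -> dict[str, float]:
    """Absolute change for each metric present in both dicts (positive = improvement)."""
    return {name: finetuned[name] - baseline[name] for name in baseline if name in finetuned}


# With a real run: metric_deltas(result.baseline.metrics, result.finetuned.metrics)
before = {"ndcg@5": 0.25, "mrr@5": 0.5}  # illustrative values only
after = {"ndcg@5": 0.5, "mrr@5": 0.75}
for name, delta in sorted(metric_deltas(before, after).items()):
    print(f"{name}: {delta:+.3f}")
```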

---

## Extensibility

### Custom models (non-HuggingFace)

Every mode supports custom PyTorch models. The pattern is the same across all three: you provide an `nn.Module` (which holds the trainable parameters) and one or more **encode functions** (which define how inputs become embeddings). khoji calls your encode functions during training with gradients enabled, handles L2 normalization, and applies LoRA/optimizer/scheduler around your module.

The key difference between modes is **what your encode functions receive and return**:

| Mode | Encode functions | Input | Output |
|------|-----------------|-------|--------|
| Text → Text | Wired automatically from model + tokenizer + pooling mode | — | — |
| Text → Image | `encode_text_fn` and `encode_image_fn` | `list[str]` (texts) and `list[str]` (image file paths/URLs) | `Tensor (batch, dim)` each |
| Composed | `encode_fn` | `(list[PIL.Image] \| None, list[str] \| None)` | `Tensor (batch, dim)` |

Note the difference: the Text → Image `encode_image_fn` receives **file paths or URLs** (your function loads them), while the Composed `encode_fn` receives **PIL images** (the trainer loads images before calling your function).

#### Text → Text

Your model must follow the HuggingFace convention: `forward(input_ids, attention_mask, ...)` returns an object with a `.last_hidden_state` attribute of shape `(batch, seq_len, hidden_dim)`. khoji applies pooling (CLS, mean, max, etc.) on top.

Your tokenizer must support `tokenizer(texts, padding=True, truncation=True, max_length=N, return_tensors="pt")`.

```python
from khoji import EmbeddingModel, Trainer, TrainingConfig

# For inference / evaluation
embedding_model = EmbeddingModel(
    model=my_encoder,           # nn.Module
    tokenizer=my_tokenizer,     # HuggingFace-compatible tokenizer
    pooling="mean",             # "cls", "mean", "max", "weightedmean", "lasttoken"
)
embeddings = embedding_model.encode(["hello world"])

# For training
trainer = Trainer(
    model=my_encoder,
    tokenizer=my_tokenizer,
    pooling="mean",
    config=TrainingConfig(
        epochs=3,
        lora=None,              # full fine-tuning (LoRA also works if your model has attention layers)
    ),
)
```

#### Text → Image

You provide two encode functions, and both receive strings: `encode_text_fn` gets query texts, and `encode_image_fn` gets image sources (file paths or URLs). Loading the images is up to `encode_image_fn`, for example with khoji's `load_image()`.

The `model` parameter should be the `nn.Module` that holds all trainable parameters. Both encode functions should run that same module (for example, by capturing it in a closure) so that gradients flow back to the parameters the optimizer updates.

```python
from khoji import MultimodalTrainer, MultimodalTrainingConfig

trainer = MultimodalTrainer(
    model=my_clip_model,          # nn.Module holding all parameters
    encode_text_fn=my_text_fn,    # (list[str]) -> Tensor (batch, dim)
    encode_image_fn=my_image_fn,  # (list[str]) -> Tensor (batch, dim)  ← receives file paths
    config=MultimodalTrainingConfig(
        epochs=3,
        lora=None,
        base_dir="./my_images",   # base directory for resolving relative image paths
    ),
)
```

#### Composed (mixed-mode)

You provide a single encode function that handles image-only, text-only, or joint mode based on which inputs are non-None. The trainer loads images before calling your function.

```python
from khoji import ComposedTrainer, ComposedTrainingConfig, JointEmbeddingModel

# For training
trainer = ComposedTrainer(
    model=my_model,          # nn.Module holding all parameters
    encode_fn=my_encode_fn,  # (list[PIL.Image]|None, list[str]|None) -> Tensor (batch, dim)
    config=ComposedTrainingConfig(
        epochs=3,
        lora=None,
        base_dir="./my_images",
    ),
)

# For inference / evaluation
model = JointEmbeddingModel(
    encoder=my_encoder_fn,  # (images: list[PIL]|None, texts: list[str]|None, device) -> Tensor
)
# The encoder must handle three calling patterns:
#   encoder([...], None, device)     → image-only embeddings
#   encoder(None, [...], device)     → text-only embeddings
#   encoder([...], [...], device)    → joint (image+text) embeddings
```
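A sketch of an encoder that dispatches on which inputs are present. `ToyComposedModel` and `make_composed_encoder` are hypothetical, and fusing the two towers with a linear layer over the concatenated embeddings is just one assumed choice. Giving `device` a default of `None` lets the same function serve as the training `encode_fn(images, texts)` and the inference `encoder(images, texts, device)`, assuming khoji passes the arguments in that order or by those names.

```python
import zlib

import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image


class ToyComposedModel(nn.Module):
    """Hypothetical stand-in: text tower, image tower, and a fusion layer."""

    def __init__(self, dim: int = 32, vocab: int = 1000, image_size: int = 8):
        super().__init__()
        self.vocab, self.image_size = vocab, image_size
        self.text_tower = nn.EmbeddingBag(vocab, dim)
        self.image_tower = nn.Linear(3 * image_size * image_size, dim)
        self.fuse = nn.Linear(2 * dim, dim)  # joint mode: fuse concatenated embeddings


def make_composed_encoder(model: ToyComposedModel):
    def encode(images=None, texts=None, device=None):
        if images is None and texts is None:
            raise ValueError("at least one of images or texts must be given")
        if images is not None and texts is not None and len(images) != len(texts):
            raise ValueError("joint mode needs one text per image")
        if device is None:
            device = next(model.parameters()).device
        img_emb = txt_emb = None
        if images is not None:
            size = model.image_size
            pixels = [
                torch.tensor(list(img.convert("RGB").resize((size, size)).getdata()),
                             dtype=torch.float32).flatten() / 255.0
                for img in images
            ]
            img_emb = model.image_tower(torch.stack(pixels).to(device))
        if texts is not None:
            ids, offsets = [], []
            for text in texts:
                offsets.append(len(ids))
                tokens = text.lower().split() or [""]
                ids.extend(zlib.crc32(tok.encode()) % model.vocab for tok in tokens)
            txt_emb = model.text_tower(torch.tensor(ids, device=device),
                                       torch.tensor(offsets, device=device))
        if img_emb is not None and txt_emb is not None:
            out = model.fuse(torch.cat([img_emb, txt_emb], dim=-1))  # joint
        else:
            out = img_emb if img_emb is not None else txt_emb  # single modality
        return F.normalize(out, dim=-1)

    return encode


my_model = ToyComposedModel()
my_encode_fn = make_composed_encoder(my_model)  # training: encode_fn(images, texts)
my_encoder_fn = my_encode_fn                    # inference: encoder(images, texts, device)
```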

#### LoRA with custom models

LoRA works with custom models as long as your `nn.Module` contains standard attention layers (Linear modules named `query`, `key`, `value`, `q_proj`, etc.). If your module uses non-standard names, specify them explicitly:

```python
from khoji.lora import LoRASettings
from khoji.trainer import TrainingConfig

config = TrainingConfig(
    lora=LoRASettings(r=8, alpha=16, target_modules=["my_attn_q", "my_attn_k", "my_attn_v"]),
)
```

If your model doesn't have attention layers suitable for LoRA, use `lora=None` for full fine-tuning.
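To choose `target_modules` for a custom model, it helps to list the `nn.Linear` layer names it actually contains. A minimal sketch using `named_modules()`. `linear_layer_names`, `MyAttention`, and `MyEncoder` are hypothetical, and it is an assumption that khoji matches these leaf names the way PEFT's `target_modules` does (by module-name suffix).

```python
import torch.nn as nn


def linear_layer_names(model: nn.Module) -> list[str]:
    """Distinct leaf names of every nn.Linear inside `model` (e.g. 'q_proj')."""
    names = set()
    for full_name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            names.add(full_name.rsplit(".", 1)[-1])
    return sorted(names)


class MyAttention(nn.Module):  # hypothetical block with non-standard names
    def __init__(self, dim: int = 16):
        super().__init__()
        self.my_attn_q = nn.Linear(dim, dim)
        self.my_attn_k = nn.Linear(dim, dim)
        self.my_attn_v = nn.Linear(dim, dim)
        self.out = nn.Linear(dim, dim)


class MyEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([MyAttention() for _ in range(2)])
        self.norm = nn.LayerNorm(16)


print(linear_layer_names(MyEncoder()))
# ['my_attn_k', 'my_attn_q', 'my_attn_v', 'out']
```

From that list you would typically pass only the attention projections (here `my_attn_q`, `my_attn_k`, `my_attn_v`) as `target_modules`.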

### Custom loss functions

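Pass any callable to `TrainingConfig.loss_fn` (Python API only). It receives the query, positive, and negative embedding batches, in that order, and returns a scalar loss: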

```python
import torch
import torch.nn.functional as F

from khoji.trainer import TrainingConfig


def circle_loss(query_emb, positive_emb, negative_emb, margin=0.25, gamma=64):
    # Circle loss with one positive and one negative per query.
    pos_sim = F.cosine_similarity(query_emb, positive_emb)
    neg_sim = F.cosine_similarity(query_emb, negative_emb)
    alpha_p = torch.clamp(1 + margin - pos_sim, min=0)
    alpha_n = torch.clamp(neg_sim + margin, min=0)
    logit_p = -gamma * alpha_p * (pos_sim - (1 - margin))
    logit_n = gamma * alpha_n * (neg_sim - margin)
    # log(1 + exp(logit_n) * exp(logit_p)); logit_p already carries its minus sign.
    return F.softplus(logit_n + logit_p).mean()


config = TrainingConfig(loss_fn=circle_loss)  # add your other settings here
```
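Before training with a custom loss, a quick check can catch sign errors: the loss should be a scalar, lower for a well-separated batch than for one where positives and negatives are swapped, and it should produce gradients. The helper below (`check_loss_fn`, hypothetical) runs those checks on random tensors. `cosine_triplet_loss` is only an example input with the same `(query_emb, positive_emb, negative_emb)` signature.

```python
import torch
import torch.nn.functional as F


def cosine_triplet_loss(query_emb, positive_emb, negative_emb, margin=0.2):
    """Example loss with the (query, positive, negative) signature."""
    pos = F.cosine_similarity(query_emb, positive_emb)
    neg = F.cosine_similarity(query_emb, negative_emb)
    return F.relu(neg - pos + margin).mean()


def check_loss_fn(loss_fn, dim: int = 8, batch: int = 4, seed: int = 0) -> dict[str, bool]:
    """Hypothetical sanity checks for a custom loss on random embeddings."""
    g = torch.Generator().manual_seed(seed)
    q = torch.randn(batch, dim, generator=g)
    noise = 0.5 * torch.randn(batch, dim, generator=g)
    # Well separated: positives near the queries, negatives pointing away.
    easy = loss_fn(q, q + noise, -q + noise)
    # Badly separated: positives and negatives swapped.
    q_grad = q.clone().requires_grad_(True)
    hard = loss_fn(q_grad, -q + noise, q + noise)
    hard.backward()
    return {
        "is_scalar": easy.dim() == 0 and hard.dim() == 0,
        "easy_below_hard": easy.item() < hard.item(),
        "has_gradient": q_grad.grad is not None and q_grad.grad.abs().sum().item() > 0,
    }


print(check_loss_fn(cosine_triplet_loss))
```

The same call works for any loss with this signature, including the circle loss above: `check_loss_fn(circle_loss)`.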

### Custom metrics

Pass `extra_metrics` to any `Evaluator.evaluate()`:

```python
def precision_at_k(ranked_doc_ids, qrel, k):
    relevant = {d for d, s in qrel.items() if s > 0}
    return sum(1 for d in ranked_doc_ids[:k] if d in relevant) / k

result = evaluator.evaluate(
    dataset=my_dataset,
    k_values=[1, 5, 10],
    extra_metrics={"precision": precision_at_k},
)
# result.metrics includes both built-in and custom metrics
```
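Any function with the same `(ranked_doc_ids, qrel, k)` signature can be passed the same way. As an illustration, here is average precision truncated at k. Normalizing by `min(#relevant, k)` is one common convention and is assumed here; `average_precision_at_k` is a hypothetical name, not a khoji export, and it is an assumption that khoji averages each extra metric across queries.

```python
def average_precision_at_k(ranked_doc_ids: list[str], qrel: dict[str, int], k: int) -> float:
    """Mean of precision@rank over the ranks (within top k) that hold a relevant doc."""
    relevant = {d for d, s in qrel.items() if s > 0}
    if not relevant:
        return 0.0
    hits, total = 0, 0.0
    for rank, doc_id in enumerate(ranked_doc_ids[:k], start=1):
        if doc_id in relevant:
            hits += 1
            total += hits / rank  # precision at this rank
    return total / min(len(relevant), k)


ranked = ["d3", "d1", "d5", "d2"]
qrel = {"d1": 2, "d5": 1}
print(round(average_precision_at_k(ranked, qrel, k=3), 4))  # 0.5833
```

It would then go in as `extra_metrics={"precision": precision_at_k, "ap": average_precision_at_k}`.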

The built-in metric functions are also exported for standalone use:

```python
from khoji import ndcg_at_k, mrr_at_k, recall_at_k

ranked = ["d3", "d1", "d5", "d2"]
qrel = {"d1": 2, "d5": 1}
print(recall_at_k(ranked, qrel, k=3))  # 1.0
```
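To see where such numbers come from, here are plain-Python reference versions of MRR and nDCG applied to the same toy ranking. They are illustrative, not khoji's implementations. In particular, this nDCG uses linear gain (`rel / log2(rank + 1)`), whereas khoji's `ndcg_at_k` may use a different gain such as `2**rel - 1`, so its values can differ.

```python
import math


def reference_mrr_at_k(ranked_doc_ids, qrel, k):
    """Reciprocal rank of the first relevant doc in the top k (0 if none)."""
    for rank, doc_id in enumerate(ranked_doc_ids[:k], start=1):
        if qrel.get(doc_id, 0) > 0:
            return 1.0 / rank
    return 0.0


def reference_ndcg_at_k(ranked_doc_ids, qrel, k):
    """nDCG with linear gain: DCG of the ranking divided by DCG of the ideal ordering."""
    dcg = sum(qrel.get(d, 0) / math.log2(rank + 1)
              for rank, d in enumerate(ranked_doc_ids[:k], start=1))
    ideal = sorted((s for s in qrel.values() if s > 0), reverse=True)[:k]
    idcg = sum(s / math.log2(rank + 1) for rank, s in enumerate(ideal, start=1))
    return dcg / idcg if idcg > 0 else 0.0


ranked = ["d3", "d1", "d5", "d2"]
qrel = {"d1": 2, "d5": 1}
print(reference_mrr_at_k(ranked, qrel, k=3))             # 0.5
print(round(reference_ndcg_at_k(ranked, qrel, k=3), 4))  # 0.6697
```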

### Custom image preprocessing (Text → Image only)

Three tiers, from most to least automatic:

1. **Auto (default)**: Loads `AutoProcessor` from HuggingFace.
2. **YAML overrides**: Override specific values (`image_size`, `mean`, `std`).
3. **Custom callable** (Python API): Full control over augmentations and transforms.

```python
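import torch
import torchvision.transforms as T  # torchvision is not a khoji dependency; install it separately
from PIL import Image

from khoji import MultimodalTrainer, MultimodalTrainingConfig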

transform = T.Compose([T.Resize(224), T.CenterCrop(224), T.ToTensor(), T.Normalize([0.5]*3, [0.5]*3)])

def my_preprocessor(images: list[Image.Image]) -> torch.Tensor:
    return torch.stack([transform(img) for img in images])

trainer = MultimodalTrainer(
    "openai/clip-vit-base-patch32",
    preprocess_overrides={"custom_fn": my_preprocessor},
    config=MultimodalTrainingConfig(...),
)
```
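`torchvision` is not among khoji's declared dependencies. If you would rather not install it, a roughly equivalent preprocessor (shorter-side resize, center crop, scale to [0, 1], normalize) can be written with Pillow and torch alone. This is a sketch: `pil_preprocessor` is a hypothetical name, and Pillow's default resampling filter is used, so outputs will not match torchvision bit for bit.

```python
import torch
from PIL import Image

SIZE = 224
MEAN = torch.tensor([0.5, 0.5, 0.5]).view(3, 1, 1)
STD = torch.tensor([0.5, 0.5, 0.5]).view(3, 1, 1)


def _to_chw(img: Image.Image, size: int = SIZE) -> torch.Tensor:
    img = img.convert("RGB")
    # Resize so the shorter side equals `size`, keeping the aspect ratio.
    w, h = img.size
    scale = size / min(w, h)
    img = img.resize((max(size, round(w * scale)), max(size, round(h * scale))))
    # Center-crop to size x size.
    w, h = img.size
    left, top = (w - size) // 2, (h - size) // 2
    img = img.crop((left, top, left + size, top + size))
    # Row-major RGB pixels -> (H, W, 3) -> (3, H, W), scaled to [0, 1].
    pixels = torch.tensor(list(img.getdata()), dtype=torch.float32).view(size, size, 3)
    return pixels.permute(2, 0, 1) / 255.0


def pil_preprocessor(images: list[Image.Image]) -> torch.Tensor:
    """Same contract as my_preprocessor: list of PIL images -> (batch, 3, H, W)."""
    return torch.stack([(_to_chw(img) - MEAN) / STD for img in images])
```

It would be passed the same way: `preprocess_overrides={"custom_fn": pil_preprocessor}`.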

---

## Architecture

```
src/khoji/
  # ── Text → Text ──────────────────────────────
  config.py                 ForgeConfig (YAML)
  run.py                    run() orchestrator
  dataset.py                load_beir, load_custom, RetrievalDataset
  data.py                   Triplet, TripletDataset, negative mining
  model.py                  EmbeddingModel (pooling auto-detection)
  trainer.py                Trainer, TrainingConfig, TrainHistory
  evaluator.py              Evaluator, EvalResult

  # ── Text → Image ─────────────────────────────
  multimodal_config.py      MultimodalForgeConfig
  multimodal_run.py         run_multimodal()
  multimodal_dataset.py     load_flickr30k, load_rsicd, load_custom_multimodal
  multimodal_data.py        MultimodalTriplet, negative mining
  multimodal_model.py       MultimodalEmbeddingModel, JointEmbeddingModel
  multimodal_trainer.py     MultimodalTrainer
  multimodal_evaluator.py   MultimodalEvaluator

  # ── Composed (mixed-mode) ────────────────────
  composed_config.py        ComposedForgeConfig
  composed_run.py           run_composed()
  composed_dataset.py       load_custom_composed, ComposedRetrievalDataset
  composed_data.py          ComposedTriplet, negative mining
  composed_trainer.py       ComposedTrainer
  composed_evaluator.py     ComposedEvaluator

  # ── Shared ────────────────────────────────────
  loss.py                   triplet_margin_loss, infonce_loss, contrastive_loss
  metrics.py                ndcg_at_k, mrr_at_k, recall_at_k
  lora.py                   LoRASettings, apply_lora
  image_utils.py            load_image, load_images_batch, build_image_processor
  device.py                 get_device (CUDA > MPS > CPU)
```

### Data flow (all modes follow the same pattern)

```
Config (YAML or Python)
  │
  ├─ Dataset loading ──> queries + corpus + qrels
  │
  ├─ Baseline evaluation (optional)
  │
  ├─ Mining round loop:
  │    │
  │    ├─ Build triplets (random / hard / mixed)
  │    │    (round 2+ uses fine-tuned model for mining)
  │    │
  │    ├─ Trainer.train() ──> TrainHistory + adapter
  │    │
  │    └─ adapter feeds next round
  │
  ├─ Fine-tuned evaluation (optional)
  │
  └─ RunResult (history + baseline + finetuned + adapter_dir)
```
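The round loop can be sketched in plain Python with the dataset, trainer, and evaluator stubbed out as callables. Everything here (`run_rounds`, the stub callables, the returned dict) is hypothetical and only mirrors the shape of the data flow; in particular it assumes the previous round's adapter is used for mining, and it does not model whether training also resumes from that adapter.

```python
from typing import Any, Callable, Optional

# Hypothetical: an "adapter" is whatever the trainer saves (e.g. a directory path);
# None stands for the base model.
Adapter = Optional[str]


def run_rounds(
    n_rounds: int,
    build_triplets: Callable[[Adapter], list],
    train: Callable[[list], tuple[Any, str]],
    evaluate: Optional[Callable[[Adapter], dict]] = None,
) -> dict:
    """Mirror of the data-flow steps; not khoji's run()."""
    baseline = evaluate(None) if evaluate else None  # optional baseline evaluation
    adapter: Adapter = None
    history = []
    for _ in range(n_rounds):
        # Round 1 mines with the base model; round 2+ with the latest adapter.
        triplets = build_triplets(adapter)
        round_history, adapter = train(triplets)
        history.append(round_history)
    finetuned = evaluate(adapter) if evaluate else None  # optional fine-tuned evaluation
    return {"history": history, "baseline": baseline, "finetuned": finetuned, "adapter": adapter}


# Usage with stubs that record what they were given.
mined_with = []


def fake_build(adapter):
    mined_with.append(adapter)
    return [("q", "pos", "neg")]


def fake_train(triplets):
    n = len(mined_with)
    return {"round": n, "n_triplets": len(triplets)}, f"adapter-round-{n}"


result = run_rounds(2, fake_build, fake_train, evaluate=lambda a: {"adapter": a})
print(mined_with)           # [None, 'adapter-round-1']
print(result["finetuned"])  # {'adapter': 'adapter-round-2'}
```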

---

## Example Scripts

| Script | Mode | Description |
|--------|------|-------------|
| `scripts/train_text_retrieval.py` | Text → Text | Config-driven + manual API on FiQA |
| `scripts/train_multimodal_retrieval.py` | Text → Image | Config-driven + manual API on RSICD |
| `scripts/train_composed_retrieval.py` | Composed | Standalone FashionIQ training (low-level) |
| `scripts/train_composed_retrieval_api.py` | Composed | Config-driven + manual API on FashionIQ |
| `scripts/train_sku_matching.py` | Mixed-mode | Cross-brand SKU matching (img+txt → img+txt) on AI-generated grocery data |
| `scripts/fashioniq/download_data.py` | Data setup | Download FashionIQ annotations (required for composed scripts) |

#### FashionIQ (composed retrieval)

```bash
python scripts/fashioniq/download_data.py
python scripts/train_composed_retrieval_api.py --approach api
```

#### SKU Matching (mixed-mode: img+txt → img+txt)

Cross-brand product matching using AI-generated grocery product images. A single model is trained and then evaluated along two generalization dimensions: unseen brands and unseen product families. See [`data/sku-matching/README.md`](data/sku-matching/README.md) for dataset details and [`notebooks/05_sku_matching_mixed_mode.ipynb`](notebooks/05_sku_matching_mixed_mode.ipynb) for a full walkthrough with visualizations.

```bash
python scripts/train_sku_matching.py                                          # default: mixed negatives, 2 rounds
python scripts/train_sku_matching.py --negatives mixed --mining-rounds 2      # explicit
python scripts/train_sku_matching.py --epochs 10 --batch-size 64             # tune hyperparams
```

---

## Example Configs

Located in `configs/`:

| Config | Mode | Description |
|--------|------|-------------|
| `minilm_scifact_full.yaml` | Text → Text | MiniLM on SciFact. Full training + evaluation. |
| `minilm_scifact_overfit.yaml` | Text → Text | Overfit on 1 batch. Pipeline debugging. |
| `clip_rsicd_full.yaml` | Text → Image | CLIP ViT-B/32 on RSICD satellite imagery. Full training + evaluation. |
| `clip_rsicd_overfit.yaml` | Text → Image | Overfit on 1 batch. Pipeline debugging. |

---

## Hardware

Auto-detected: CUDA (1st) > MPS (2nd) > CPU (3rd).
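This priority order maps onto the standard PyTorch availability checks. The function below is a sketch (`pick_device` is a hypothetical name, not khoji's `get_device`) that can be handy for checking which backend a given machine will report; khoji's own function may differ in details.

```python
import torch


def pick_device() -> torch.device:
    """Hypothetical mirror of the CUDA > MPS > CPU priority."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():  # Apple Silicon (Metal)
        return torch.device("mps")
    return torch.device("cpu")


print(pick_device())
```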

**MPS tip**: If you hit OOM, reduce `batch_size` and increase `grad_accum_steps` to keep the same effective batch size (`batch_size` × `grad_accum_steps`).
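Why the trade works: for a loss that is a mean over independent examples, accumulating gradients over several micro-batches, each loss divided by the number of steps, gives the same gradient as one large batch. The toy check below (a linear model with MSE, not khoji code) compares the two. One caveat: losses that use other items in the batch as negatives, such as InfoNCE with in-batch negatives, see fewer negatives per micro-batch, so accumulation does not exactly reproduce a larger batch for them.

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(4, 1)
x, y = torch.randn(64, 4), torch.randn(64, 1)
mse = torch.nn.MSELoss()  # mean over examples

# One batch of 64.
model.zero_grad()
mse(model(x), y).backward()
full_grad = model.weight.grad.clone()

# Same effective batch: batch_size=16 x grad_accum_steps=4.
grad_accum_steps = 4
model.zero_grad()
for xb, yb in zip(x.chunk(grad_accum_steps), y.chunk(grad_accum_steps)):
    (mse(model(xb), yb) / grad_accum_steps).backward()  # gradients add up across calls
accum_grad = model.weight.grad.clone()

print(torch.allclose(full_grad, accum_grad, atol=1e-6))  # True
```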

---

## Development

### Running tests

```bash
uv run pytest tests/ -v    # 132 tests
```

### Test coverage

| Module | Tests |
|--------|-------|
| `metrics.py` | nDCG, MRR, Recall — edge cases, graded relevance, k cutoffs |
| `model.py` | All pooling modes, auto-detection, L2 normalization |
| `data.py` | Random/hard/mixed negatives, determinism, correctness |
| `loss.py` | All 3 losses — shapes, values, gradient flow |
| `config.py` | YAML roundtrip, type coercion, defaults |
| `lora.py` | apply_lora, auto-detection, custom targets |
| `evaluator.py` | Custom datasets, extra metrics, serialization |
| `trainer.py` | Training loop, history tracking |
| `dataset.py` | load_custom, missing files, RetrievalDataset |
| `multimodal` | CLIP encoding, config, datasets, training, LoRA targeting, evaluation |
| `composed` | Dataset format, triplets, config YAML, custom model training, evaluation |
| `integration` | BEIR loading, retrieval sanity checks |

### Linting

```bash
uv run ruff check src/ tests/
```

---

## Roadmap

- [x] Text → Text retrieval (BERT, BGE, sentence-transformers)
- [x] Text → Image retrieval (CLIP, SigLIP)
- [x] Composed mixed-mode retrieval (BLIP-2) — image, text, or both on query and target sides
- [x] Full fine-tuning (`lora: null`)
- [x] Custom models, loss functions, metrics
- [ ] Validation loss tracking during training
- [ ] Early stopping
- [ ] Distributed training (multi-GPU via DDP)
- [ ] Checkpoint resumption
- [ ] Adapter merging (LoRA → base model)
- [ ] Logging integration (W&B, TensorBoard)

---

## License

MIT
