Coverage for llm_dataset_engine/stages/prompt_formatter_stage.py: 19%
69 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-15 18:04 +0200
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-15 18:04 +0200
1"""Prompt formatting stage for template-based prompt generation."""
3from decimal import Decimal
4from typing import Any, Dict, List
6import pandas as pd
7from jinja2 import Template as Jinja2Template
9from llm_dataset_engine.core.models import (
10 CostEstimate,
11 PromptBatch,
12 RowMetadata,
13 ValidationResult,
14)
15from llm_dataset_engine.core.specifications import PromptSpec
16from llm_dataset_engine.stages.pipeline_stage import PipelineStage
class PromptFormatterStage(
    PipelineStage[tuple[pd.DataFrame, PromptSpec], List[PromptBatch]]
):
    """
    Format prompts using template and row data.

    Responsibilities:
    - Extract input columns from rows
    - Format prompts using template (``str.format`` or Jinja2)
    - Batch prompts for efficient processing
    - Attach metadata for tracking
    """

    def __init__(
        self, batch_size: int = 100, use_jinja2: bool = False
    ):
        """
        Initialize prompt formatter stage.

        Args:
            batch_size: Number of prompts per batch; must be at least 1
            use_jinja2: Use Jinja2 for template rendering instead of
                ``str.format``

        Raises:
            ValueError: If ``batch_size`` is less than 1. A zero batch size
                would make batching fail later, and a negative one would
                silently produce no batches (dropping every prompt).
        """
        if batch_size < 1:
            raise ValueError(
                f"batch_size must be a positive integer, got {batch_size}"
            )
        super().__init__("PromptFormatter")
        self.batch_size = batch_size
        self.use_jinja2 = use_jinja2

    def process(
        self, input_data: tuple[pd.DataFrame, PromptSpec], context: Any
    ) -> List[PromptBatch]:
        """
        Format one prompt per DataFrame row and group the prompts into batches.

        Each prompt is laid out as
        ``"<system_message>\\n\\n<few-shot examples>\\n\\n<rendered template>"``,
        omitting the system-message and example sections when they are
        empty.

        A row whose prompt or metadata cannot be built is logged and skipped.
        This covers a missing template variable, an invalid format spec, or
        a failure constructing ``RowMetadata``. A skipped row contributes
        neither a prompt nor a metadata entry, so ``prompts[i]`` and
        ``metadata[i]`` of every batch always describe the same row.

        Args:
            input_data: ``(df, prompt_spec)`` pair; one prompt per row of ``df``
            context: Pipeline execution context (not used by this stage)

        Returns:
            Batches of at most ``self.batch_size`` prompts, with ``batch_id``
            numbered consecutively from 0.
        """
        df, prompt_spec = input_data
        template_str = prompt_spec.template

        # Columns the template may reference. Only string labels can be
        # passed as keyword arguments, and ``int in str`` would raise
        # TypeError, so non-string labels are skipped. The substring test
        # may over-select; that is harmless because str.format and
        # Template.render both ignore unused keyword arguments. The list is
        # computed once here instead of on every row.
        template_columns = [
            col
            for col in df.columns
            if isinstance(col, str) and col in template_str
        ]
        record_columns = list(template_columns)
        if "id" in df.columns and "id" not in record_columns:
            record_columns.append("id")  # needed for RowMetadata.row_id

        # to_dict("records") keeps each column's own dtype. iterrows()
        # upcasts a row of mixed int/float columns to float, which would
        # render 1 as "1.0" in prompts.
        if record_columns:
            records = df.loc[:, record_columns].to_dict("records")
        else:
            # A column-less selection may yield no records at all, so emit
            # one empty mapping per row instead.
            records = [{} for _ in range(len(df))]

        # Choose the renderer once (Jinja2 or str.format).
        if self.use_jinja2:
            render = Jinja2Template(template_str).render
        else:
            render = template_str.format

        # The system message and few-shot examples are identical for every
        # row, so build the shared prefix once.
        prefix_parts: List[str] = []
        if prompt_spec.system_message:
            prefix_parts.append(prompt_spec.system_message)
        if prompt_spec.few_shot_examples:
            prefix_parts.append(
                self._format_few_shot_examples(prompt_spec.few_shot_examples)
            )
        prefix = "".join(f"{part}\n\n" for part in prefix_parts)

        prompts: List[str] = []
        metadata_list: List[RowMetadata] = []

        for idx, record in zip(df.index, records):
            row_data = {col: record[col] for col in template_columns}
            try:
                prompt = prefix + render(**row_data)
                metadata = RowMetadata(
                    row_index=idx,
                    row_id=record.get("id"),
                )
            except KeyError as e:
                self.logger.warning(
                    f"Missing template variable at row {idx}: {e}"
                )
                continue
            except Exception as e:
                self.logger.error(
                    f"Error formatting prompt at row {idx}: {e}"
                )
                continue

            # Append only after both steps succeeded so the two lists can
            # never drift out of alignment.
            prompts.append(prompt)
            metadata_list.append(metadata)

        batches: List[PromptBatch] = [
            PromptBatch(
                prompts=prompts[start : start + self.batch_size],
                metadata=metadata_list[start : start + self.batch_size],
                batch_id=start // self.batch_size,
            )
            for start in range(0, len(prompts), self.batch_size)
        ]

        self.logger.info(
            f"Formatted {len(prompts)} prompts into {len(batches)} batches"
        )

        return batches

    def _format_few_shot_examples(
        self, examples: List[Dict[str, str]]
    ) -> str:
        """
        Format few-shot examples for prompt.

        Args:
            examples: List of example dicts with 'input' and 'output' keys;
                a missing key is rendered as an empty string

        Returns:
            Formatted examples text: a header line, then one numbered
            "Example/Input/Output" group per example, each followed by a
            blank line
        """
        formatted = ["Here are some examples:\n"]

        for i, example in enumerate(examples, 1):
            formatted.append(f"Example {i}:")
            formatted.append(f"Input: {example.get('input', '')}")
            formatted.append(f"Output: {example.get('output', '')}")
            formatted.append("")

        return "\n".join(formatted)

    @staticmethod
    def _format_template_variables(template: str) -> set[str]:
        """
        Return the root names referenced by a ``str.format`` template.

        ``{user.name}`` and ``{items[0]}`` yield ``user`` and ``items``.
        Escaped ``{{braces}}`` and auto-numbered ``{}`` fields are ignored.
        Fields nested in a format spec, e.g. ``{value:{width}}``, are
        included.

        Raises:
            ValueError: If the template has unbalanced braces.
        """
        import string

        names: set[str] = set()
        for _text, field_name, format_spec, _conversion in (
            string.Formatter().parse(template)
        ):
            # field_name is None for trailing literal text and "" for
            # auto-numbered "{}" fields; neither names a column.
            if field_name:
                names.add(field_name.split(".", 1)[0].split("[", 1)[0])
            if format_spec and "{" in format_spec:
                names |= PromptFormatterStage._format_template_variables(
                    format_spec
                )
        return names

    @staticmethod
    def _jinja2_template_variables(template: str) -> set[str]:
        """
        Return the undeclared variables a Jinja2 template reads.

        Jinja2's default globals (``range``, ``dict``, ...) are excluded
        because they are always available at render time.

        Raises:
            jinja2.TemplateSyntaxError: If the template cannot be parsed.
        """
        from jinja2 import Environment, meta

        env = Environment()
        undeclared = meta.find_undeclared_variables(env.parse(template))
        return set(undeclared) - set(env.globals)

    def validate_input(
        self, input_data: tuple[pd.DataFrame, PromptSpec]
    ) -> ValidationResult:
        """
        Validate DataFrame and prompt specification.

        Checks the following:
        - the DataFrame has at least one row;
        - the template parses;
        - every variable the template references is a DataFrame column.

        Variables are extracted with the same engine ``process`` renders
        with: ``string.Formatter`` for ``str.format`` templates, or the
        Jinja2 AST when ``use_jinja2`` is set.

        Args:
            input_data: ``(df, prompt_spec)`` pair

        Returns:
            ValidationResult carrying an error for each failed check
        """
        from jinja2 import TemplateSyntaxError

        result = ValidationResult(is_valid=True)

        df, prompt_spec = input_data

        # Check DataFrame not empty
        if df.empty:
            result.add_error("DataFrame is empty")

        # Check template parses and its variables exist in DataFrame
        template = prompt_spec.template
        try:
            if self.use_jinja2:
                variables = self._jinja2_template_variables(template)
            else:
                variables = self._format_template_variables(template)
        except (ValueError, TemplateSyntaxError) as e:
            result.add_error(f"Invalid prompt template: {e}")
            return result

        missing_vars = variables - set(df.columns)

        if missing_vars:
            result.add_error(
                f"Template variables not in DataFrame: {missing_vars}"
            )

        return result

    def estimate_cost(
        self, input_data: tuple[pd.DataFrame, PromptSpec]
    ) -> CostEstimate:
        """
        Prompt formatting has no LLM cost.

        Returns:
            A zero-cost, zero-token estimate whose ``rows`` is the number
            of DataFrame rows.
        """
        return CostEstimate(
            total_cost=Decimal("0.0"),
            total_tokens=0,
            input_tokens=0,
            output_tokens=0,
            rows=len(input_data[0]),
        )