Coverage for llm_dataset_engine/stages/prompt_formatter_stage.py: 19%

69 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-10-15 18:04 +0200

1"""Prompt formatting stage for template-based prompt generation.""" 

2 

3from decimal import Decimal 

4from typing import Any, Dict, List 

5 

6import pandas as pd 

7from jinja2 import Template as Jinja2Template 

8 

9from llm_dataset_engine.core.models import ( 

10 CostEstimate, 

11 PromptBatch, 

12 RowMetadata, 

13 ValidationResult, 

14) 

15from llm_dataset_engine.core.specifications import PromptSpec 

16from llm_dataset_engine.stages.pipeline_stage import PipelineStage 

17 

18 

class PromptFormatterStage(
    PipelineStage[tuple[pd.DataFrame, PromptSpec], List[PromptBatch]]
):
    """
    Format prompts using template and row data.

    Responsibilities:
    - Extract input columns from rows
    - Format prompts using template
    - Batch prompts for efficient processing
    - Attach metadata for tracking
    """

    def __init__(
        self, batch_size: int = 100, use_jinja2: bool = False
    ):
        """
        Initialize prompt formatter stage.

        Args:
            batch_size: Number of prompts per batch (must be at least 1)
            use_jinja2: Use Jinja2 for template rendering instead of
                ``str.format``

        Raises:
            ValueError: If ``batch_size`` is smaller than 1
        """
        # A zero step makes range() raise a cryptic error during batching and
        # a negative step makes it empty, which would silently drop every
        # prompt. Reject both up front.
        if batch_size < 1:
            raise ValueError(
                f"batch_size must be a positive integer, got {batch_size!r}"
            )

        super().__init__("PromptFormatter")
        self.batch_size = batch_size
        self.use_jinja2 = use_jinja2

    def process(
        self, input_data: tuple[pd.DataFrame, PromptSpec], context: Any
    ) -> List[PromptBatch]:
        """
        Format prompts from DataFrame rows.

        Each row is rendered through the template; the few-shot examples and
        the system message (when set on the spec) are prepended, in that
        order, so the system message ends up first. Rows that fail to render
        are logged and skipped.

        Args:
            input_data: Tuple of (DataFrame, prompt specification)
            context: Not used by this stage

        Returns:
            Batches of at most ``batch_size`` prompts, each prompt paired
            with the metadata of the row it was built from
        """
        df, prompt_spec = input_data

        prompts: List[str] = []
        metadata_list: List[RowMetadata] = []

        template_str = prompt_spec.template

        # Create template renderer (compiled once, reused for every row)
        template = Jinja2Template(template_str) if self.use_jinja2 else None

        # The set of columns handed to the template does not depend on the
        # row, so select it once. Non-string labels are skipped: they cannot
        # be keyword arguments, and `label in template_str` would raise
        # TypeError for them.
        template_columns = [
            col
            for col in df.columns
            if isinstance(col, str) and col in template_str
        ]

        # Few-shot text is identical for every row. It is built lazily inside
        # the per-row try block so a malformed example is still reported
        # through the per-row error handling rather than aborting the stage.
        examples_text = None

        # Format prompt for each row
        for idx, row in df.iterrows():
            try:
                # Extract input columns
                row_data = {col: row[col] for col in template_columns}

                # Format prompt (Jinja2 or str.format)
                if template is not None:
                    prompt = template.render(**row_data)
                else:
                    prompt = template_str.format(**row_data)

                # Add few-shot examples if specified
                if prompt_spec.few_shot_examples:
                    if examples_text is None:
                        examples_text = self._format_few_shot_examples(
                            prompt_spec.few_shot_examples
                        )
                    prompt = f"{examples_text}\n\n{prompt}"

                # Add system message if specified
                if prompt_spec.system_message:
                    prompt = f"{prompt_spec.system_message}\n\n{prompt}"

                # Build the metadata BEFORE recording the prompt: if this
                # raises, neither list is touched and prompts stay aligned
                # with their metadata.
                metadata = RowMetadata(
                    row_index=idx,
                    row_id=row.get("id", None),
                )

            except KeyError as e:
                self.logger.warning(
                    f"Missing template variable at row {idx}: {e}"
                )
                continue
            except Exception as e:
                self.logger.error(
                    f"Error formatting prompt at row {idx}: {e}"
                )
                continue

            prompts.append(prompt)
            metadata_list.append(metadata)

        # Create batches
        batches: List[PromptBatch] = []
        for i in range(0, len(prompts), self.batch_size):
            batch = PromptBatch(
                prompts=prompts[i : i + self.batch_size],
                metadata=metadata_list[i : i + self.batch_size],
                batch_id=i // self.batch_size,
            )
            batches.append(batch)

        self.logger.info(
            f"Formatted {len(prompts)} prompts into {len(batches)} batches"
        )

        return batches

    def _format_few_shot_examples(
        self, examples: List[Dict[str, str]]
    ) -> str:
        """
        Format few-shot examples for prompt.

        Args:
            examples: List of example dicts with 'input' and 'output';
                a missing key is rendered as an empty string

        Returns:
            Formatted examples text: a header line followed by one
            "Example N / Input / Output" group per example
        """
        formatted = ["Here are some examples:\n"]

        for i, example in enumerate(examples, 1):
            formatted.append(f"Example {i}:")
            formatted.append(f"Input: {example.get('input', '')}")
            formatted.append(f"Output: {example.get('output', '')}")
            formatted.append("")

        return "\n".join(formatted)

    def validate_input(
        self, input_data: tuple[pd.DataFrame, PromptSpec]
    ) -> ValidationResult:
        """
        Validate DataFrame and prompt specification.

        Args:
            input_data: Tuple of (DataFrame, prompt specification)

        Returns:
            Validation result carrying an error when the DataFrame is empty,
            when a ``str.format`` template is malformed, or when the
            template references variables that are not DataFrame columns
        """
        import re
        import string

        result = ValidationResult(is_valid=True)

        df, prompt_spec = input_data

        # Check DataFrame not empty
        if df.empty:
            result.add_error("DataFrame is empty")

        # Check template variables exist in DataFrame
        template = prompt_spec.template

        if self.use_jinja2:
            # NOTE(review): this only recognises brace-wrapped bare names
            # (e.g. "{{name}}"), not "{{ name }}" or expressions. It is kept
            # as-is so Jinja2 templates that validate today keep validating.
            variables = set(re.findall(r"\{(\w+)\}", template))
        else:
            # Parse with the same grammar str.format() uses, so escaped
            # braces ("{{literal}}") are not mistaken for variables and
            # fields such as "{price:.2f}", "{name!r}" or "{user.name}"
            # are recognised.
            try:
                variables = {
                    # Keep only the root name of "a.b" / "a[0]" style fields.
                    re.split(r"[.\[]", field_name, maxsplit=1)[0]
                    for _, field_name, _, _ in string.Formatter().parse(
                        template
                    )
                    if field_name
                }
            except ValueError as e:
                # Unbalanced braces etc.: str.format() would fail on every row.
                result.add_error(f"Invalid prompt template: {e}")
                return result

        missing_vars = variables - set(df.columns)

        if missing_vars:
            result.add_error(
                f"Template variables not in DataFrame: {missing_vars}"
            )

        return result

    def estimate_cost(
        self, input_data: tuple[pd.DataFrame, PromptSpec]
    ) -> CostEstimate:
        """
        Prompt formatting has no LLM cost.

        Args:
            input_data: Tuple of (DataFrame, prompt specification)

        Returns:
            A zero-cost, zero-token estimate whose ``rows`` is the number of
            DataFrame rows
        """
        return CostEstimate(
            total_cost=Decimal("0.0"),
            total_tokens=0,
            input_tokens=0,
            output_tokens=0,
            rows=len(input_data[0]),
        )

187