Coverage for llm_dataset_engine/stages/response_parser_stage.py: 62%
102 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-15 18:04 +0200
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-15 18:04 +0200
1"""Response parsing stage for structured output extraction."""
3import json
4import re
5from abc import ABC, abstractmethod
6from decimal import Decimal
7from typing import Any, Dict, List, Optional, Type
9import pandas as pd
10from pydantic import BaseModel, ValidationError
12from llm_dataset_engine.core.models import (
13 CostEstimate,
14 ResponseBatch,
15 ValidationResult,
16)
17from llm_dataset_engine.stages.pipeline_stage import PipelineStage
class ResponseParser(ABC):
    """
    Interface implemented by every response parser (Strategy pattern).

    A concrete subclass turns one raw response string into a dictionary
    of extracted fields.
    """

    @abstractmethod
    def parse(self, response: str) -> Dict[str, Any]:
        """
        Convert a raw response into structured data.

        Args:
            response: Raw text returned by the model.

        Returns:
            Mapping of field names to extracted values.
        """
        ...
class RawTextParser(ResponseParser):
    """Pass-through parser that keeps the response text under ``output``."""

    def parse(self, response: str) -> Dict[str, Any]:
        """Wrap the whitespace-trimmed response in a single-key dict."""
        cleaned = response.strip()
        return dict(output=cleaned)
class JSONParser(ResponseParser):
    """
    Parser that extracts JSON from a response.

    Accepts a bare JSON document or, in non-strict mode, JSON wrapped in a
    markdown code fence. JSON values that are not objects (lists, numbers,
    strings) are wrapped as ``{"output": value}``, so the result is always
    a dict.
    """

    def __init__(self, strict: bool = False):
        """
        Initialize JSON parser.

        Args:
            strict: If True, raise ``json.JSONDecodeError`` when the whole
                response is not valid JSON. If False, try to extract JSON
                from a markdown code fence and fall back to
                ``{"output": <stripped response>}`` when nothing decodes.
        """
        self.strict = strict

    def parse(self, response: str) -> Dict[str, Any]:
        """
        Parse JSON from a response.

        Args:
            response: Raw model output.

        Returns:
            The decoded JSON object, ``{"output": value}`` for a non-object
            JSON value, or ``{"output": <stripped response>}`` when no JSON
            could be decoded in non-strict mode.

        Raises:
            json.JSONDecodeError: In strict mode, when the stripped response
                is not valid JSON.
        """
        text = response.strip()
        try:
            return self._as_dict(json.loads(text))
        except json.JSONDecodeError:
            if self.strict:
                raise

        fenced = self._extract_fenced_block(response)
        if fenced is not None:
            try:
                return self._as_dict(json.loads(fenced))
            except json.JSONDecodeError:
                # Non-strict mode must not fail: a fence holding something
                # other than JSON falls through to the raw-text result.
                pass

        return {"output": text}

    @staticmethod
    def _extract_fenced_block(response: str) -> Optional[str]:
        """
        Return the stripped body of a markdown code fence, or None.

        A fence tagged ``json`` is preferred over an untagged one. A fence
        with no closing marker (e.g. a truncated response) runs to the end
        of the text rather than losing its final character.
        """
        for marker in ("```json", "```"):
            start = response.find(marker)
            if start == -1:
                continue
            start += len(marker)
            end = response.find("```", start)
            if end == -1:
                end = len(response)
            return response[start:end].strip()
        return None

    @staticmethod
    def _as_dict(value: Any) -> Dict[str, Any]:
        """Wrap non-object JSON values so callers always receive a dict."""
        return value if isinstance(value, dict) else {"output": value}
class PydanticParser(ResponseParser):
    """
    Parser that validates responses against Pydantic models.

    Provides type-safe extraction with automatic validation. The response
    is decoded with a non-strict ``JSONParser``, and the decoded data is
    validated against ``model``.
    """

    def __init__(
        self, model: Type[BaseModel], strict: bool = True
    ):
        """
        Initialize Pydantic parser.

        Args:
            model: Pydantic model class for validation
            strict: If True, raise on decode/validation errors; if False,
                return the raw text together with the error message.
        """
        self.model = model
        self.strict = strict

    def parse(self, response: str) -> Dict[str, Any]:
        """
        Parse and validate a response with the Pydantic model.

        Args:
            response: Raw model output.

        Returns:
            ``model_dump()`` of the validated model on success. In non-strict
            mode, ``{"output": <stripped response>, "validation_error": msg}``
            on any decode or validation failure.

        Raises:
            ValueError: In strict mode, when Pydantic validation fails
                (chained from the ``ValidationError``).
            json.JSONDecodeError: In strict mode, when the JSON parser
                raises while decoding the response.
        """
        try:
            data = JSONParser(strict=False).parse(response)
        except json.JSONDecodeError as e:
            # A non-strict JSONParser may still raise (e.g. on malformed
            # fenced content); apply this parser's own strictness instead.
            if self.strict:
                raise
            return {"output": response.strip(), "validation_error": str(e)}

        try:
            if isinstance(data, dict):
                validated = self.model(**data)
            else:
                # model(**data) would raise TypeError on a JSON list/scalar;
                # model_validate reports it as a ValidationError instead.
                validated = self.model.model_validate(data)
        except ValidationError as e:
            if self.strict:
                raise ValueError(f"Pydantic validation failed: {e}") from e
            # Return raw data if validation fails
            return {"output": response.strip(), "validation_error": str(e)}

        return validated.model_dump()
class RegexParser(ResponseParser):
    """
    Parser that pulls named fields out of free text with regular expressions.

    Useful for extracting specific fields from structured text. Each field
    has one pattern, and the first match anywhere in the response is used.
    """

    def __init__(self, patterns: Dict[str, str]):
        """
        Initialize regex parser.

        Args:
            patterns: Mapping of field name to regex pattern string. Every
                pattern is compiled once here and reused for each response.
        """
        compiled = {}
        for name, source in patterns.items():
            compiled[name] = re.compile(source)
        self.patterns = compiled

    def parse(self, response: str) -> Dict[str, Any]:
        """
        Extract every configured field from the response.

        Returns:
            One key per configured field. The value is capture group 1 when
            the pattern defines groups, otherwise the whole match. Fields
            whose pattern does not match map to None.
        """
        return {
            name: self._extract(regex.search(response))
            for name, regex in self.patterns.items()
        }

    @staticmethod
    def _extract(match: Optional[re.Match]) -> Optional[str]:
        """Pick group 1 if the pattern has groups, else the full match."""
        if match is None:
            return None
        return match.group(1) if match.groups() else match.group(0)
class ResponseParserStage(
    PipelineStage[
        tuple[List[ResponseBatch], List[str]], pd.DataFrame
    ]
):
    """
    Parse LLM responses into structured DataFrame.

    Responsibilities:
    - Parse responses using configured parser
    - Map parsed data to output columns
    - Handle parse errors gracefully
    - Return DataFrame with results
    """

    def __init__(
        self,
        parser: ResponseParser | None = None,
        output_columns: List[str] | None = None,
    ):
        """
        Initialize response parser stage.

        Args:
            parser: Response parser (default: RawTextParser)
            output_columns: Output column names (default: ["output"]).
                NOTE(review): ``process`` takes column names from its
                ``input_data`` tuple; this attribute is stored but not read
                by this class. Confirm whether callers rely on it.
        """
        super().__init__("ResponseParser")
        self.parser = parser or RawTextParser()
        self.output_columns = output_columns or ["output"]

    def process(
        self,
        input_data: tuple[List[ResponseBatch], List[str]],
        context: Any,
    ) -> pd.DataFrame:
        """
        Parse responses into a DataFrame.

        Args:
            input_data: ``(batches, output_cols)``. Each batch pairs its
                ``responses`` with ``metadata`` entries carrying ``row_index``.
            context: Not used by this stage.

        Returns:
            DataFrame indexed by ``row_index`` whose columns are
            ``output_cols`` (in order), even when there are no responses.
            Rows whose response failed to parse hold None in every column.
        """
        batches, output_cols = input_data

        results: Dict[int, Dict[str, Any]] = {}

        for batch in batches:
            if len(batch.responses) != len(batch.metadata):
                # zip() stops at the shorter sequence; make the drop visible.
                self.logger.warning(
                    f"Batch has {len(batch.responses)} responses but "
                    f"{len(batch.metadata)} metadata entries; unmatched "
                    f"items are ignored"
                )
            for response, metadata in zip(batch.responses, batch.metadata):
                try:
                    parsed = self.parser.parse(response)
                    results[metadata.row_index] = self._map_to_columns(
                        parsed, output_cols
                    )
                except Exception as e:
                    # Per-row boundary: one bad response must not abort
                    # the whole stage, so record None and keep going.
                    self.logger.error(
                        f"Failed to parse response at row "
                        f"{metadata.row_index}: {e}"
                    )
                    results[metadata.row_index] = {
                        col: None for col in output_cols
                    }

        # Explicit columns keep the schema stable when results is empty;
        # dict.fromkeys drops duplicate names while preserving order.
        df = pd.DataFrame.from_dict(
            results,
            orient="index",
            columns=list(dict.fromkeys(output_cols)),
        )
        df.index.name = "row_index"

        self.logger.info(f"Parsed {len(results)} responses")

        return df

    @staticmethod
    def _map_to_columns(
        parsed: Any, output_cols: List[str]
    ) -> Dict[str, Any]:
        """
        Project one parser result onto the requested output columns.

        Single column: use the value stored under the column's own name if
        present, otherwise the ``output`` key, otherwise the dict's first
        value (None for an empty dict). A non-dict result is used as-is.
        Multiple columns: look each column up by name; missing keys give
        None. A non-dict result raises ``AttributeError`` here, which the
        caller treats as a parse failure.
        """
        if len(output_cols) == 1:
            col = output_cols[0]
            if not isinstance(parsed, dict):
                return {col: parsed}
            if col in parsed:
                return {col: parsed[col]}
            if "output" in parsed:
                return {col: parsed["output"]}
            return {col: next(iter(parsed.values()), None)}
        return {col: parsed.get(col, None) for col in output_cols}

    def validate_input(
        self, input_data: tuple[List[ResponseBatch], List[str]]
    ) -> ValidationResult:
        """
        Validate response batches.

        Records an error when no batches or no output columns are given.
        """
        result = ValidationResult(is_valid=True)

        batches, output_cols = input_data

        if not batches:
            result.add_error("No response batches provided")

        if not output_cols:
            result.add_error("No output columns specified")

        return result

    def estimate_cost(
        self, input_data: tuple[List[ResponseBatch], List[str]]
    ) -> CostEstimate:
        """Response parsing has no LLM cost; only the row count is reported."""
        return CostEstimate(
            total_cost=Decimal("0.0"),
            total_tokens=0,
            input_tokens=0,
            output_tokens=0,
            rows=sum(len(b.responses) for b in input_data[0]),
        )