Coverage for llm_dataset_engine/stages/response_parser_stage.py: 62%

102 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-10-15 18:04 +0200

1"""Response parsing stage for structured output extraction.""" 

2 

3import json 

4import re 

5from abc import ABC, abstractmethod 

6from decimal import Decimal 

7from typing import Any, Dict, List, Optional, Type 

8 

9import pandas as pd 

10from pydantic import BaseModel, ValidationError 

11 

12from llm_dataset_engine.core.models import ( 

13 CostEstimate, 

14 ResponseBatch, 

15 ValidationResult, 

16) 

17from llm_dataset_engine.stages.pipeline_stage import PipelineStage 

18 

19 

class ResponseParser(ABC):
    """Strategy interface for turning a raw LLM response into structured data.

    Concrete subclasses choose how the text is interpreted (raw text, JSON,
    Pydantic-validated JSON, regex extraction, ...).
    """

    @abstractmethod
    def parse(self, response: str) -> Dict[str, Any]:
        """Convert ``response`` into a mapping of field names to values."""
        ...

27 

28 

class RawTextParser(ResponseParser):
    """Pass-through parser: keeps the response text under the ``output`` key."""

    def parse(self, response: str) -> Dict[str, Any]:
        """Return ``{"output": response}`` with surrounding whitespace removed."""
        cleaned = response.strip()
        return dict(output=cleaned)

35 

36 

class JSONParser(ResponseParser):
    """Parser that extracts JSON from a response.

    The whole (stripped) response is tried as JSON first. In non-strict mode
    a JSON payload wrapped in a markdown code fence is tried next (a fence
    tagged ``json`` wins over a bare fence), and anything that still cannot
    be decoded is returned as raw text under the ``output`` key.
    """

    def __init__(self, strict: bool = False):
        """
        Initialize JSON parser.

        Args:
            strict: If True, re-raise ``json.JSONDecodeError`` when the whole
                response is not valid JSON; if False, try to extract JSON
                from a markdown code fence and fall back to raw text.
        """
        self.strict = strict

    def parse(self, response: str) -> Dict[str, Any]:
        """Parse JSON from ``response``.

        The decoded value is returned as-is, so a top-level JSON array,
        number or string yields a non-dict value despite the annotation.

        Raises:
            json.JSONDecodeError: Only in strict mode, when the stripped
                response is not valid JSON.
        """
        text = response.strip()
        try:
            return json.loads(text)
        except json.JSONDecodeError:
            if self.strict:
                raise

        fenced = self._extract_fenced_block(response)
        if fenced is not None:
            try:
                return json.loads(fenced)
            except json.JSONDecodeError:
                # Non-strict mode must not fail: a fence containing invalid
                # JSON falls through to the raw-text result below.
                pass

        # Return as raw text if it can't be parsed.
        return {"output": text}

    @staticmethod
    def _extract_fenced_block(response: str) -> Optional[str]:
        """Return the stripped contents of the most relevant code fence, or None.

        A fence opened with ```json (case-insensitive) is preferred; otherwise
        the first bare ``` fence is used. An unterminated fence extends to the
        end of the response.
        """
        tagged = re.search(r"```json", response, re.IGNORECASE)
        if tagged is not None:
            start = tagged.end()
        else:
            opening = response.find("```")
            if opening == -1:
                return None
            start = opening + len("```")

        end = response.find("```", start)
        if end == -1:
            # No closing fence: take everything to the end instead of slicing
            # with -1, which would silently drop the final character.
            end = len(response)
        return response[start:end].strip()

71 

72 

class PydanticParser(ResponseParser):
    """
    Parser that validates responses against Pydantic models.

    Provides type-safe extraction with automatic validation. The response is
    decoded with a non-strict ``JSONParser`` and the resulting JSON object is
    passed to ``model`` as keyword arguments.
    """

    def __init__(self, model: Type[BaseModel], strict: bool = True):
        """
        Initialize Pydantic parser.

        Args:
            model: Pydantic model class for validation
            strict: If True, fail on validation errors; if False, return the
                raw response together with the error message
        """
        self.model = model
        self.strict = strict

    def parse(self, response: str) -> Dict[str, Any]:
        """Parse and validate ``response`` with the Pydantic model.

        Returns:
            ``model_dump()`` of the validated model on success. In non-strict
            mode, any failure yields ``{"output": <stripped response>,
            "validation_error": <message>}`` instead.

        Raises:
            json.JSONDecodeError: In strict mode, if decoding the response
                raises it (a ``ValueError`` subclass).
            ValueError: In strict mode, if the decoded JSON is not an object
                or fails model validation.
        """
        try:
            data = JSONParser(strict=False).parse(response)
        except json.JSONDecodeError as e:
            # Defensive: decode errors are parse failures and must not escape
            # non-strict mode.
            if self.strict:
                raise
            return self._failure_payload(response, e)

        if not isinstance(data, dict):
            # ``self.model(**data)`` would raise an unhandled TypeError for
            # arrays/scalars; route it through the normal failure path.
            error = f"Expected a JSON object, got {type(data).__name__}"
            if self.strict:
                raise ValueError(f"Pydantic validation failed: {error}")
            return self._failure_payload(response, error)

        try:
            validated = self.model(**data)
        except ValidationError as e:
            if self.strict:
                raise ValueError(f"Pydantic validation failed: {e}") from e
            return self._failure_payload(response, e)

        return validated.model_dump()

    @staticmethod
    def _failure_payload(response: str, error: Any) -> Dict[str, Any]:
        """Build the non-strict failure result: raw text plus the error message."""
        return {"output": response.strip(), "validation_error": str(error)}

110 

111 

class RegexParser(ResponseParser):
    """
    Parser that pulls individual fields out of a response with regexes.

    Useful for extracting specific fields from structured text.
    """

    def __init__(self, patterns: Dict[str, str]):
        """
        Compile one regex per output field.

        Args:
            patterns: Mapping of field name -> regex pattern string
        """
        compiled = {}
        for name, expr in patterns.items():
            compiled[name] = re.compile(expr)
        self.patterns = compiled

    def parse(self, response: str) -> Dict[str, Any]:
        """Search ``response`` with every pattern.

        Each field receives capture group 1 when its pattern defines groups,
        the whole match otherwise, and ``None`` when the pattern did not match.
        """
        extracted: Dict[str, Any] = {}
        for name, regex in self.patterns.items():
            found = regex.search(response)
            if found is None:
                extracted[name] = None
            elif found.groups():
                extracted[name] = found.group(1)
            else:
                extracted[name] = found.group(0)
        return extracted

147 

148 

class ResponseParserStage(
    PipelineStage[tuple[List[ResponseBatch], List[str]], pd.DataFrame]
):
    """
    Parse LLM responses into structured DataFrame.

    Responsibilities:
    - Parse responses using configured parser
    - Map parsed data to output columns
    - Handle parse errors gracefully
    - Return DataFrame with results
    """

    def __init__(
        self,
        parser: ResponseParser | None = None,
        output_columns: List[str] | None = None,
    ):
        """
        Initialize response parser stage.

        Args:
            parser: Response parser (default: RawTextParser)
            output_columns: Output column names (default: ["output"]).
                NOTE(review): ``process`` takes its columns from
                ``input_data`` and never reads this attribute — confirm
                whether it is used elsewhere.
        """
        super().__init__("ResponseParser")
        self.parser = parser or RawTextParser()
        self.output_columns = output_columns or ["output"]

    def process(
        self,
        input_data: tuple[List[ResponseBatch], List[str]],
        context: Any,
    ) -> pd.DataFrame:
        """Parse responses into a DataFrame indexed by ``row_index``.

        Args:
            input_data: ``(batches, output_cols)``. Each batch's
                ``responses`` are paired positionally with its ``metadata``
                entries, whose ``row_index`` becomes the DataFrame index.
            context: Pipeline context (not used by this stage).

        Returns:
            One row per response and one column per (deduplicated) output
            column. Rows whose response failed to parse hold None. Empty
            input still yields a frame with the output columns.
        """
        batches, output_cols = input_data

        results: Dict[int, Dict[str, Any]] = {}

        for batch in batches:
            for response, metadata in zip(batch.responses, batch.metadata):
                try:
                    parsed = self.parser.parse(response)
                    results[metadata.row_index] = self._map_to_columns(
                        parsed, output_cols
                    )
                except Exception as e:
                    # Best-effort: one bad response must not fail the batch.
                    self.logger.error(
                        f"Failed to parse response at row "
                        f"{metadata.row_index}: {e}"
                    )
                    # Store None for failed parses
                    results[metadata.row_index] = {
                        col: None for col in output_cols
                    }

        # Passing the columns keeps the output schema even when no rows were
        # parsed; dict.fromkeys dedups while preserving order, matching the
        # per-row dict keys.
        df = pd.DataFrame.from_dict(
            results,
            orient="index",
            columns=list(dict.fromkeys(output_cols)),
        )
        df.index.name = "row_index"

        self.logger.info(f"Parsed {len(results)} responses")

        return df

    @staticmethod
    def _map_to_columns(parsed: Any, output_cols: List[str]) -> Dict[str, Any]:
        """Project one parser result onto the requested output columns."""
        if len(output_cols) == 1:
            column = output_cols[0]
            if not isinstance(parsed, dict):
                return {column: parsed}
            if "output" in parsed:
                return {column: parsed["output"]}
            # Use the first value; an empty dict maps to None instead of
            # raising StopIteration and being logged as a parse failure.
            return {column: next(iter(parsed.values()), None)}

        # Multiple output columns: look each one up by name. A non-dict
        # result raises AttributeError here, which ``process`` logs and turns
        # into an all-None row.
        return {col: parsed.get(col, None) for col in output_cols}

    def validate_input(
        self, input_data: tuple[List[ResponseBatch], List[str]]
    ) -> ValidationResult:
        """Validate that batches and output columns are both non-empty."""
        result = ValidationResult(is_valid=True)

        batches, output_cols = input_data

        if not batches:
            result.add_error("No response batches provided")

        if not output_cols:
            result.add_error("No output columns specified")

        return result

    def estimate_cost(
        self, input_data: tuple[List[ResponseBatch], List[str]]
    ) -> CostEstimate:
        """Response parsing has no LLM cost; only the row count is reported."""
        return CostEstimate(
            total_cost=Decimal("0.0"),
            total_tokens=0,
            input_tokens=0,
            output_tokens=0,
            rows=sum(len(b.responses) for b in input_data[0]),
        )

265