Coverage for llm_dataset_engine/core/models.py: 92%

107 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-10-15 18:04 +0200

"""
Core data models for execution results and metadata.

These models represent the outputs and state information from pipeline
execution, following clean code principles with type safety.
"""

7 

8from dataclasses import dataclass, field 

9from datetime import datetime 

10from decimal import Decimal 

11from typing import Any, Dict, List, Optional 

12from uuid import UUID, uuid4 

13 

14import pandas as pd 

15 

16 

@dataclass
class LLMResponse:
    """Response from a single LLM invocation.

    Plain data container: the generated dataclass ``__init__`` stores the
    values as given and performs no validation.
    """

    text: str
    tokens_in: int
    tokens_out: int
    model: str
    cost: Decimal
    latency_ms: float
    # default_factory builds a fresh dict per instance, so instances never
    # share one mutable default.
    metadata: Dict[str, Any] = field(default_factory=dict)

28 

29 

@dataclass
class CostEstimate:
    """Cost estimation for pipeline execution.

    Plain data container; values are stored as given, and no consistency
    check between the token fields is performed here.
    """

    total_cost: Decimal
    total_tokens: int
    input_tokens: int
    output_tokens: int
    rows: int
    # Per-instance dict keyed by str (presumably the stage name -- confirm
    # against callers).
    breakdown_by_stage: Dict[str, Decimal] = field(default_factory=dict)
    # One of: "estimate", "sample-based", "actual". Not enforced here.
    confidence: str = "estimate"

41 

42 

@dataclass
class ProcessingStats:
    """Statistics from pipeline execution.

    Plain data container; none of the values are derived or validated
    here -- the caller supplies every figure, including the rate.
    """

    total_rows: int
    processed_rows: int
    failed_rows: int
    skipped_rows: int
    rows_per_second: float
    total_duration_seconds: float
    # Per-instance dict of str -> float (presumably stage name -> seconds;
    # confirm against the code that fills it).
    stage_durations: Dict[str, float] = field(default_factory=dict)

54 

55 

@dataclass
class ErrorInfo:
    """Information about an error during processing.

    Plain data container. ``timestamp`` has no default and must be passed
    by the caller.
    """

    row_index: int
    stage_name: str
    error_type: str
    error_message: str
    timestamp: datetime
    # Free-form extra details; a fresh dict is created per instance.
    context: Dict[str, Any] = field(default_factory=dict)

66 

67 

@dataclass
class ExecutionResult:
    """Complete result from pipeline execution.

    Bundles the output frame with its statistics, cost figures and any
    collected errors. Two read-only properties derive values from the
    stored fields: ``duration`` and ``error_rate``.
    """

    data: pd.DataFrame
    metrics: ProcessingStats
    costs: CostEstimate
    errors: List[ErrorInfo] = field(default_factory=list)
    # A new random UUID is generated for every instance.
    execution_id: UUID = field(default_factory=uuid4)
    # NOTE: datetime.now() without a tz argument yields a naive local-time
    # value. An ``end_time`` assigned later must also be naive, otherwise the
    # subtraction in ``duration`` raises TypeError.
    start_time: datetime = field(default_factory=datetime.now)
    end_time: Optional[datetime] = None
    success: bool = True
    metadata: Dict[str, Any] = field(default_factory=dict)

    @property
    def duration(self) -> float:
        """Get execution duration in seconds.

        Returns:
            ``end_time - start_time`` in seconds, or ``0.0`` while
            ``end_time`` has not been set.
        """
        if self.end_time is None:
            return 0.0
        return (self.end_time - self.start_time).total_seconds()

    @property
    def error_rate(self) -> float:
        """Get error rate as percentage.

        Returns:
            ``failed_rows / total_rows * 100`` taken from ``metrics``, or
            ``0.0`` when ``total_rows`` is zero (avoids division by zero).
        """
        if self.metrics.total_rows == 0:
            return 0.0
        return (
            self.metrics.failed_rows / self.metrics.total_rows
        ) * 100

97 

98 

@dataclass
class ValidationResult:
    """Result from validation checks.

    ``errors`` and ``warnings`` accumulate messages. Recording an error
    through ``add_error`` also marks the result invalid; recording a
    warning leaves ``is_valid`` untouched.
    """

    is_valid: bool
    errors: List[str] = field(default_factory=list)
    warnings: List[str] = field(default_factory=list)

    def add_error(self, error: str) -> None:
        """Add an error message and flip ``is_valid`` to ``False``."""
        self.errors.append(error)
        self.is_valid = False

    def add_warning(self, warning: str) -> None:
        """Add a warning message; validity is not affected."""
        self.warnings.append(warning)

115 

116 

@dataclass
class WriteConfirmation:
    """Confirmation of successful data write.

    Plain data container. ``success`` is a required field supplied by the
    caller; nothing here verifies that the write actually happened.
    """

    path: str
    rows_written: int
    success: bool
    # Defaults to the construction time; datetime.now() without a tz
    # argument returns a naive local-time value.
    timestamp: datetime = field(default_factory=datetime.now)
    metadata: Dict[str, Any] = field(default_factory=dict)

126 

127 

@dataclass
class CheckpointInfo:
    """Information about a checkpoint.

    Plain data container; every field is required and stored as given.
    """

    session_id: UUID
    checkpoint_path: str
    row_index: int
    stage_index: int
    timestamp: datetime
    size_bytes: int

138 

139 

@dataclass
class RowMetadata:
    """Metadata for a single row during processing.

    Only ``row_index`` is required; the remaining fields fall back to the
    defaults shown below.
    """

    row_index: int
    row_id: Optional[Any] = None
    batch_id: Optional[int] = None
    # Starts at 1 (presumably incremented on retry -- confirm against the
    # code that updates it).
    attempt: int = 1
    # Free-form extras; a fresh dict is created per instance.
    custom: Dict[str, Any] = field(default_factory=dict)

149 

150 

@dataclass
class PromptBatch:
    """Batch of prompts for processing.

    Plain data container; all three fields are required. Nothing here
    checks that ``prompts`` and ``metadata`` have the same length
    (presumably they are index-aligned -- confirm against callers).
    """

    prompts: List[str]
    metadata: List[RowMetadata]
    batch_id: int

158 

159 

@dataclass
class ResponseBatch:
    """Batch of responses from LLM.

    Plain data container. Nothing here checks that ``responses``,
    ``metadata`` and ``latencies_ms`` have matching lengths (presumably
    they are index-aligned -- confirm against callers).
    """

    responses: List[str]
    metadata: List[RowMetadata]
    tokens_used: int
    cost: Decimal
    batch_id: int
    # Optional; defaults to a fresh empty list per instance.
    latencies_ms: List[float] = field(default_factory=list)

170