Coverage for llm_dataset_engine/core/models.py: 92%
107 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-15 18:04 +0200
1"""
2Core data models for execution results and metadata.
4These models represent the outputs and state information from pipeline
5execution, following clean code principles with type safety.
6"""
8from dataclasses import dataclass, field
9from datetime import datetime
10from decimal import Decimal
11from typing import Any, Dict, List, Optional
12from uuid import UUID, uuid4
14import pandas as pd
@dataclass
class LLMResponse:
    """Outcome of one call to a language model.

    Bundles the generated text with its token counts, the serving model's
    identifier, the cost attributed to the call, and the observed latency.
    """

    # Text returned by the model.
    text: str
    # Token counts for the prompt side (in) and the completion side (out).
    tokens_in: int
    tokens_out: int
    # Identifier of the model that produced the response.
    model: str
    # Cost attributed to this single call (Decimal, exact decimal arithmetic).
    cost: Decimal
    # Call latency; the field name indicates milliseconds.
    latency_ms: float
    # Free-form extra data; each instance gets its own fresh dict.
    metadata: Dict[str, Any] = field(default_factory=dict)
@dataclass
class CostEstimate:
    """Projected (or measured) cost of running a pipeline.

    Aggregates token counts and total cost for a given number of rows, with
    an optional per-stage cost breakdown and a label describing how the
    figure was obtained.
    """

    # Aggregate cost across all stages.
    total_cost: Decimal
    # Token totals; input/output are tracked separately alongside the sum.
    total_tokens: int
    input_tokens: int
    output_tokens: int
    # Number of rows the estimate covers.
    rows: int
    # Cost per stage name; empty unless a caller fills it in.
    breakdown_by_stage: Dict[str, Decimal] = field(default_factory=dict)
    # How the figure was derived: "estimate", "sample-based" or "actual".
    confidence: str = "estimate"
@dataclass
class ProcessingStats:
    """Row counts, throughput and timing gathered from a pipeline run."""

    # Row accounting for the run.
    total_rows: int
    processed_rows: int
    failed_rows: int
    skipped_rows: int
    # Throughput and wall-clock duration of the whole run.
    rows_per_second: float
    total_duration_seconds: float
    # Duration recorded per stage name; empty by default.
    stage_durations: Dict[str, float] = field(default_factory=dict)
@dataclass
class ErrorInfo:
    """Record of a single failure raised while processing a row."""

    # Where the failure happened: the row and the pipeline stage.
    row_index: int
    stage_name: str
    # What went wrong, kept as plain strings.
    error_type: str
    error_message: str
    # When the failure was recorded (supplied by the caller, no default).
    timestamp: datetime
    # Arbitrary additional details; a fresh dict per instance.
    context: Dict[str, Any] = field(default_factory=dict)
@dataclass
class ExecutionResult:
    """Everything a pipeline run produced: output data, stats, cost, errors.

    ``start_time`` defaults to ``datetime.now()`` at construction (a naive
    local timestamp); ``end_time`` remains ``None`` until a caller sets it.
    """

    # Output table of the run.
    data: pd.DataFrame
    # Row counts and timing statistics.
    metrics: ProcessingStats
    # Cost figures associated with the run.
    costs: CostEstimate
    # Failures collected during the run; a fresh list per instance.
    errors: List[ErrorInfo] = field(default_factory=list)
    # A new random UUID per result unless one is supplied.
    execution_id: UUID = field(default_factory=uuid4)
    start_time: datetime = field(default_factory=datetime.now)
    end_time: Optional[datetime] = None
    success: bool = True
    metadata: Dict[str, Any] = field(default_factory=dict)

    @property
    def duration(self) -> float:
        """Seconds between ``start_time`` and ``end_time``; 0.0 if not finished."""
        finished_at = self.end_time
        return (
            0.0
            if finished_at is None
            else (finished_at - self.start_time).total_seconds()
        )

    @property
    def error_rate(self) -> float:
        """Failed rows as a percentage of ``metrics.total_rows``.

        Returns 0.0 when ``total_rows`` is 0, avoiding a division by zero.
        """
        stats = self.metrics
        total = stats.total_rows
        if total == 0:
            return 0.0
        # Divide first, then scale: keeps float results identical.
        return (stats.failed_rows / total) * 100
@dataclass
class ValidationResult:
    """Outcome of validation checks: a validity flag plus collected messages.

    Errors invalidate the result; warnings are informational and never
    change ``is_valid``. The rule "any recorded error means not valid"
    holds both for errors added via ``add_error`` and for errors passed
    to the constructor.
    """

    is_valid: bool
    # Messages that make the result invalid.
    errors: List[str] = field(default_factory=list)
    # Informational messages that do not affect validity.
    warnings: List[str] = field(default_factory=list)

    def __post_init__(self) -> None:
        # A result constructed with errors cannot be valid; apply the same
        # rule add_error() enforces so the flag and the list never disagree.
        if self.errors:
            self.is_valid = False

    def add_error(self, error: str) -> None:
        """Record ``error`` and mark the result as invalid."""
        self.errors.append(error)
        self.is_valid = False

    def add_warning(self, warning: str) -> None:
        """Record ``warning``; ``is_valid`` is left unchanged."""
        self.warnings.append(warning)
@dataclass
class WriteConfirmation:
    """Acknowledgement returned after writing data to a destination."""

    # Destination the data was written to, and how many rows went out.
    path: str
    rows_written: int
    # Whether the write succeeded.
    success: bool
    # Defaults to datetime.now() at construction (naive local time).
    timestamp: datetime = field(default_factory=datetime.now)
    # Free-form extra details; a fresh dict per instance.
    metadata: Dict[str, Any] = field(default_factory=dict)
@dataclass
class CheckpointInfo:
    """Descriptor for a saved checkpoint; every field must be supplied."""

    # Session the checkpoint belongs to.
    session_id: UUID
    # Location of the checkpoint file.
    checkpoint_path: str
    # Progress markers captured in the checkpoint.
    row_index: int
    stage_index: int
    # When the checkpoint was taken.
    timestamp: datetime
    # Size of the checkpoint, in bytes.
    size_bytes: int
@dataclass
class RowMetadata:
    """Bookkeeping attached to one row while it moves through processing."""

    # Position of the row (required).
    row_index: int
    # Optional identifier of the row, of any type.
    row_id: Optional[Any] = None
    # Batch the row was placed in, if any.
    batch_id: Optional[int] = None
    # Attempt counter, starting at 1.
    attempt: int = 1
    # Caller-defined extras; a fresh dict per instance.
    custom: Dict[str, Any] = field(default_factory=dict)
@dataclass
class PromptBatch:
    """A group of prompts submitted for processing together.

    ``metadata`` carries RowMetadata entries for the batch — presumably one
    per prompt, aligned by position with ``prompts``; verify against callers,
    since nothing here enforces it.
    """

    # Prompt texts in this batch.
    prompts: List[str]
    # Row bookkeeping for the batch's prompts.
    metadata: List[RowMetadata]
    # Identifier of this batch.
    batch_id: int
@dataclass
class ResponseBatch:
    """Responses returned by the LLM for one batch, with usage figures.

    ``metadata`` and ``latencies_ms`` are presumably aligned by position
    with ``responses``; nothing here enforces that — confirm with callers.
    """

    # Response texts for the batch.
    responses: List[str]
    # Row bookkeeping for the batch's responses.
    metadata: List[RowMetadata]
    # Batch-level usage and cost.
    tokens_used: int
    cost: Decimal
    # Identifier of this batch.
    batch_id: int
    # Latency figures (field name indicates milliseconds); empty by default.
    latencies_ms: List[float] = field(default_factory=list)