Coverage for llm_dataset_engine/orchestration/execution_context.py: 70%
40 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-15 18:04 +0200
"""
Execution context for carrying runtime state between stages.

Implements the Memento pattern for checkpoint serialization.
"""
7from dataclasses import dataclass, field
8from datetime import datetime
9from decimal import Decimal
10from typing import Any, Dict
11from uuid import UUID, uuid4
13from llm_dataset_engine.core.models import ProcessingStats
16@dataclass
17class ExecutionContext:
18 """
19 Runtime state container for pipeline execution.
21 Carries shared state between stages and tracks progress.
22 Immutable for most fields to prevent accidental modification.
23 """
25 session_id: UUID = field(default_factory=uuid4)
26 pipeline_id: UUID = field(default_factory=uuid4)
27 start_time: datetime = field(default_factory=datetime.now)
28 end_time: datetime | None = None
30 # Progress tracking
31 current_stage_index: int = 0
32 last_processed_row: int = 0
33 total_rows: int = 0
35 # Cost tracking
36 accumulated_cost: Decimal = field(default_factory=lambda: Decimal("0.0"))
37 accumulated_tokens: int = 0
39 # Intermediate data storage
40 intermediate_data: Dict[str, Any] = field(default_factory=dict)
42 # Statistics
43 failed_rows: int = 0
44 skipped_rows: int = 0
46 def update_stage(self, stage_index: int) -> None:
47 """Update current stage."""
48 self.current_stage_index = stage_index
50 def update_row(self, row_index: int) -> None:
51 """Update last processed row."""
52 self.last_processed_row = row_index
54 def add_cost(self, cost: Decimal, tokens: int) -> None:
55 """Add cost and token usage."""
56 self.accumulated_cost += cost
57 self.accumulated_tokens += tokens
59 def get_progress(self) -> float:
60 """Get completion percentage."""
61 if self.total_rows == 0:
62 return 0.0
63 return (self.last_processed_row / self.total_rows) * 100
65 def get_stats(self) -> ProcessingStats:
66 """Get processing statistics."""
67 duration = (
68 (datetime.now() - self.start_time).total_seconds()
69 if self.end_time is None
70 else (self.end_time - self.start_time).total_seconds()
71 )
73 rows_per_second = (
74 self.last_processed_row / duration if duration > 0 else 0.0
75 )
77 return ProcessingStats(
78 total_rows=self.total_rows,
79 processed_rows=self.last_processed_row,
80 failed_rows=self.failed_rows,
81 skipped_rows=self.skipped_rows,
82 rows_per_second=rows_per_second,
83 total_duration_seconds=duration,
84 )
86 def to_checkpoint(self) -> Dict[str, Any]:
87 """
88 Serialize to checkpoint dictionary (Memento pattern).
90 Returns:
91 Dictionary representation for persistence
92 """
93 return {
94 "session_id": str(self.session_id),
95 "pipeline_id": str(self.pipeline_id),
96 "start_time": self.start_time.isoformat(),
97 "end_time": self.end_time.isoformat() if self.end_time else None,
98 "current_stage_index": self.current_stage_index,
99 "last_processed_row": self.last_processed_row,
100 "total_rows": self.total_rows,
101 "accumulated_cost": str(self.accumulated_cost),
102 "accumulated_tokens": self.accumulated_tokens,
103 "intermediate_data": self.intermediate_data,
104 "failed_rows": self.failed_rows,
105 "skipped_rows": self.skipped_rows,
106 }
108 @classmethod
109 def from_checkpoint(cls, data: Dict[str, Any]) -> "ExecutionContext":
110 """
111 Deserialize from checkpoint dictionary.
113 Args:
114 data: Checkpoint data
116 Returns:
117 Restored ExecutionContext
118 """
119 return cls(
120 session_id=UUID(data["session_id"]),
121 pipeline_id=UUID(data["pipeline_id"]),
122 start_time=datetime.fromisoformat(data["start_time"]),
123 end_time=(
124 datetime.fromisoformat(data["end_time"])
125 if data.get("end_time")
126 else None
127 ),
128 current_stage_index=data["current_stage_index"],
129 last_processed_row=data["last_processed_row"],
130 total_rows=data["total_rows"],
131 accumulated_cost=Decimal(data["accumulated_cost"]),
132 accumulated_tokens=data["accumulated_tokens"],
133 intermediate_data=data.get("intermediate_data", {}),
134 failed_rows=data.get("failed_rows", 0),
135 skipped_rows=data.get("skipped_rows", 0),
136 )