Coverage for llm_dataset_engine/orchestration/execution_context.py: 70%

40 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-10-15 18:04 +0200

1""" 

2Execution context for carrying runtime state between stages. 

3 

4Implements Memento pattern for checkpoint serialization. 

5""" 

6 

7from dataclasses import dataclass, field 

8from datetime import datetime 

9from decimal import Decimal 

10from typing import Any, Dict 

11from uuid import UUID, uuid4 

12 

13from llm_dataset_engine.core.models import ProcessingStats 

14 

15 

16@dataclass 

17class ExecutionContext: 

18 """ 

19 Runtime state container for pipeline execution. 

20  

21 Carries shared state between stages and tracks progress. 

22 Immutable for most fields to prevent accidental modification. 

23 """ 

24 

25 session_id: UUID = field(default_factory=uuid4) 

26 pipeline_id: UUID = field(default_factory=uuid4) 

27 start_time: datetime = field(default_factory=datetime.now) 

28 end_time: datetime | None = None 

29 

30 # Progress tracking 

31 current_stage_index: int = 0 

32 last_processed_row: int = 0 

33 total_rows: int = 0 

34 

35 # Cost tracking 

36 accumulated_cost: Decimal = field(default_factory=lambda: Decimal("0.0")) 

37 accumulated_tokens: int = 0 

38 

39 # Intermediate data storage 

40 intermediate_data: Dict[str, Any] = field(default_factory=dict) 

41 

42 # Statistics 

43 failed_rows: int = 0 

44 skipped_rows: int = 0 

45 

46 def update_stage(self, stage_index: int) -> None: 

47 """Update current stage.""" 

48 self.current_stage_index = stage_index 

49 

50 def update_row(self, row_index: int) -> None: 

51 """Update last processed row.""" 

52 self.last_processed_row = row_index 

53 

54 def add_cost(self, cost: Decimal, tokens: int) -> None: 

55 """Add cost and token usage.""" 

56 self.accumulated_cost += cost 

57 self.accumulated_tokens += tokens 

58 

59 def get_progress(self) -> float: 

60 """Get completion percentage.""" 

61 if self.total_rows == 0: 

62 return 0.0 

63 return (self.last_processed_row / self.total_rows) * 100 

64 

65 def get_stats(self) -> ProcessingStats: 

66 """Get processing statistics.""" 

67 duration = ( 

68 (datetime.now() - self.start_time).total_seconds() 

69 if self.end_time is None 

70 else (self.end_time - self.start_time).total_seconds() 

71 ) 

72 

73 rows_per_second = ( 

74 self.last_processed_row / duration if duration > 0 else 0.0 

75 ) 

76 

77 return ProcessingStats( 

78 total_rows=self.total_rows, 

79 processed_rows=self.last_processed_row, 

80 failed_rows=self.failed_rows, 

81 skipped_rows=self.skipped_rows, 

82 rows_per_second=rows_per_second, 

83 total_duration_seconds=duration, 

84 ) 

85 

86 def to_checkpoint(self) -> Dict[str, Any]: 

87 """ 

88 Serialize to checkpoint dictionary (Memento pattern). 

89 

90 Returns: 

91 Dictionary representation for persistence 

92 """ 

93 return { 

94 "session_id": str(self.session_id), 

95 "pipeline_id": str(self.pipeline_id), 

96 "start_time": self.start_time.isoformat(), 

97 "end_time": self.end_time.isoformat() if self.end_time else None, 

98 "current_stage_index": self.current_stage_index, 

99 "last_processed_row": self.last_processed_row, 

100 "total_rows": self.total_rows, 

101 "accumulated_cost": str(self.accumulated_cost), 

102 "accumulated_tokens": self.accumulated_tokens, 

103 "intermediate_data": self.intermediate_data, 

104 "failed_rows": self.failed_rows, 

105 "skipped_rows": self.skipped_rows, 

106 } 

107 

108 @classmethod 

109 def from_checkpoint(cls, data: Dict[str, Any]) -> "ExecutionContext": 

110 """ 

111 Deserialize from checkpoint dictionary. 

112 

113 Args: 

114 data: Checkpoint data 

115 

116 Returns: 

117 Restored ExecutionContext 

118 """ 

119 return cls( 

120 session_id=UUID(data["session_id"]), 

121 pipeline_id=UUID(data["pipeline_id"]), 

122 start_time=datetime.fromisoformat(data["start_time"]), 

123 end_time=( 

124 datetime.fromisoformat(data["end_time"]) 

125 if data.get("end_time") 

126 else None 

127 ), 

128 current_stage_index=data["current_stage_index"], 

129 last_processed_row=data["last_processed_row"], 

130 total_rows=data["total_rows"], 

131 accumulated_cost=Decimal(data["accumulated_cost"]), 

132 accumulated_tokens=data["accumulated_tokens"], 

133 intermediate_data=data.get("intermediate_data", {}), 

134 failed_rows=data.get("failed_rows", 0), 

135 skipped_rows=data.get("skipped_rows", 0), 

136 ) 

137