Coverage for llm_dataset_engine/stages/pipeline_stage.py: 41%
46 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-15 18:04 +0200
1"""
2Base pipeline stage abstraction.
4Defines the contract for all processing stages using Template Method
5pattern for execution flow.
6"""
8from abc import ABC, abstractmethod
9from typing import Any, Generic, TypeVar
11from llm_dataset_engine.core.models import CostEstimate, ValidationResult
12from llm_dataset_engine.utils import get_logger
# Type variables for input and output: they parameterize PipelineStage so
# that each concrete stage can declare the data type it accepts and the
# data type it produces.
TInput = TypeVar("TInput")
TOutput = TypeVar("TOutput")

# Module-level logger obtained from the project's get_logger helper.
# NOTE(review): not referenced in the visible code; stages log through
# their own per-instance self.logger instead.
logger = get_logger(__name__)
class PipelineStage(ABC, Generic[TInput, TOutput]):
    """
    Abstract base class for all pipeline stages.

    Implements Template Method pattern with hooks for extensibility:
    ``execute`` drives the fixed flow (pre-hook, validation, processing,
    post-hook, error hook) while subclasses supply ``process``,
    ``validate_input`` and ``estimate_cost`` and may override the hooks.
    All stages follow Single Responsibility and are composable.
    """

    def __init__(self, name: str):
        """
        Initialize pipeline stage.

        Args:
            name: Human-readable stage name (also used as the suffix of
                the per-stage logger name)
        """
        self.name = name
        self.logger = get_logger(f"{__name__}.{name}")

    @abstractmethod
    def process(self, input_data: TInput, context: Any) -> TOutput:
        """
        Core processing logic (must be implemented by subclasses).

        Args:
            input_data: Input data for this stage
            context: Execution context with shared state

        Returns:
            Processed output data
        """
        pass

    @abstractmethod
    def validate_input(self, input_data: TInput) -> ValidationResult:
        """
        Validate input before processing.

        Args:
            input_data: Input to validate

        Returns:
            ValidationResult with errors/warnings; ``execute`` reads its
            ``is_valid``, ``errors`` and ``warnings`` attributes
        """
        pass

    def execute(self, input_data: TInput, context: Any) -> TOutput:
        """
        Execute stage with pre/post hooks (Template Method).

        Flow: ``before_process`` -> ``validate_input`` -> ``process`` ->
        ``after_process``. This method orchestrates the execution flow
        and should not be overridden.

        Args:
            input_data: Input data
            context: Execution context

        Returns:
            Processed output

        Raises:
            ValueError: If input validation fails
            Exception: Whatever ``on_error`` returns for a failure in
                ``process`` or ``after_process``. A transformed exception
                is chained to the original via ``__cause__``; if
                ``on_error`` returns the same exception (or something
                that is not an exception), the original is re-raised.
        """
        self.logger.info(f"Starting stage: {self.name}")

        # Pre-processing hook (runs before validation)
        self.before_process(context)

        # Validate input
        validation = self.validate_input(input_data)
        if not validation.is_valid:
            error_msg = f"Input validation failed: {validation.errors}"
            self.logger.error(error_msg)
            raise ValueError(error_msg)

        if validation.warnings:
            for warning in validation.warnings:
                self.logger.warning(warning)

        # Core processing
        try:
            result = self.process(input_data, context)
            self.logger.info(f"Completed stage: {self.name}")

            # Post-processing hook (failures here also go through on_error)
            self.after_process(result, context)

            return result
        except Exception as e:
            self.logger.error(f"Stage {self.name} failed: {e}")
            error_decision = self.on_error(e, context)

            # Same exception back, or a hook that returned a non-exception
            # (e.g. forgot to return): re-raise the original as-is so its
            # traceback is untouched and no TypeError masks the failure.
            if error_decision is e or not isinstance(
                error_decision, BaseException
            ):
                raise

            # Transformed error: chain explicitly so the root cause is kept.
            raise error_decision from e

    def before_process(self, context: Any) -> None:
        """
        Hook called before processing (default: no-op).

        Args:
            context: Execution context
        """
        pass

    def after_process(self, result: TOutput, context: Any) -> None:
        """
        Hook called after successful processing (default: no-op).

        Args:
            result: Processing result
            context: Execution context
        """
        pass

    def on_error(self, error: Exception, context: Any) -> Exception:
        """
        Hook called on processing error (default: return it unchanged).

        The hook itself does not raise; ``execute`` raises whatever this
        method returns.

        Args:
            error: The exception that occurred
            context: Execution context

        Returns:
            Exception to raise (can transform error)
        """
        return error

    @abstractmethod
    def estimate_cost(self, input_data: TInput) -> CostEstimate:
        """
        Estimate processing cost for this stage.

        Args:
            input_data: Input data to estimate for

        Returns:
            Cost estimate
        """
        pass