Coverage for llm_dataset_engine/integrations/airflow.py: 0%

57 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-10-15 18:04 +0200

1""" 

2Airflow integration - Pre-built operators for Apache Airflow. 

3 

4Provides LLMTransformOperator for easy integration into Airflow DAGs. 

5""" 

6 

7from typing import Any, Dict, Optional 

8 

# Availability of Airflow is decided solely by whether BaseOperator can be
# imported; everything below keys off AIRFLOW_AVAILABLE.
try:
    from airflow.models import BaseOperator

    AIRFLOW_AVAILABLE = True
except ImportError:
    AIRFLOW_AVAILABLE = False
    BaseOperator = object  # Placeholder

# `apply_defaults` is imported separately on purpose: it has been a deprecated
# no-op since Airflow 2.0 and is not importable on newer releases. Importing it
# in the same `try` as BaseOperator would wrongly flag Airflow as missing.
try:
    from airflow.utils.decorators import apply_defaults
except ImportError:

    def apply_defaults(func):
        """Return *func* unchanged; stand-in for Airflow's legacy decorator."""
        return func

17 

18from llm_dataset_engine.api import Pipeline 

19from llm_dataset_engine.config import ConfigLoader 

20 

21 

22if AIRFLOW_AVAILABLE: 

23 

class LLMTransformOperator(BaseOperator):
    """
    Airflow operator for LLM dataset transformations.

    Integrates LLM Dataset Engine into Airflow DAGs with minimal boilerplate.
    The operator loads a pipeline configuration, applies optional overrides
    (budget, provider, model), runs the pipeline on data taken either from
    XCom or from a file, and hands the result back through XCom.

    Example:
        from llm_dataset_engine.integrations.airflow import LLMTransformOperator

        llm_task = LLMTransformOperator(
            task_id='llm_enrichment',
            config_path='configs/llm_config.yaml',
            input_xcom_key='raw_data',
            input_task_id='extract',
            output_xcom_key='enriched_data',
            max_budget=10.0,
            dag=dag,
        )
    """

    @apply_defaults
    def __init__(
        self,
        config_path: str,
        input_xcom_key: Optional[str] = None,
        output_xcom_key: str = "llm_result",
        input_file: Optional[str] = None,
        output_file: Optional[str] = None,
        max_budget: Optional[float] = None,
        provider_override: Optional[str] = None,
        model_override: Optional[str] = None,
        input_task_id: Optional[str] = None,
        *args,
        **kwargs,
    ):
        """
        Initialize LLM transform operator.

        Args:
            config_path: Path to the pipeline configuration file
            input_xcom_key: XCom key to pull input data from
            output_xcom_key: XCom key to push the result under
            input_file: Path to input file (alternative to XCom)
            output_file: Path to output file (written as CSV)
            max_budget: Override maximum budget
            provider_override: Override LLM provider
            model_override: Override model name
            input_task_id: Task id to pull ``input_xcom_key`` from. ``None``
                (the default) leaves the source task unspecified, exactly
                as before this parameter existed.
            *args: Airflow BaseOperator args
            **kwargs: Airflow BaseOperator kwargs
        """
        super().__init__(*args, **kwargs)
        self.config_path = config_path
        self.input_xcom_key = input_xcom_key
        self.input_task_id = input_task_id
        self.output_xcom_key = output_xcom_key
        self.input_file = input_file
        self.output_file = output_file
        self.max_budget = max_budget
        self.provider_override = provider_override
        self.model_override = model_override

    def execute(self, context: Dict[str, Any]) -> Any:
        """
        Execute LLM transformation.

        Args:
            context: Airflow task context; ``context['ti']`` is used for
                XCom pull/push.

        Returns:
            ``result.data`` of the executed pipeline (also pushed to XCom
            under ``output_xcom_key`` when that key is set).

        Raises:
            ValueError: If the XCom pull yields ``None``, or if neither
                ``input_xcom_key`` nor ``input_file`` was provided.
        """
        # Load configuration.
        # NOTE(review): only the YAML loader is used here; confirm whether
        # JSON configs are expected to work through this entry point.
        specs = ConfigLoader.from_yaml(self.config_path)

        # Override settings
        if self.max_budget is not None:
            from decimal import Decimal

            # Go through str() so the Decimal carries the float's printed
            # value rather than its binary expansion.
            specs.processing.max_budget = Decimal(str(self.max_budget))

        if self.provider_override:
            from llm_dataset_engine.core.specifications import LLMProvider

            specs.llm.provider = LLMProvider(self.provider_override)

        if self.model_override:
            specs.llm.model = self.model_override

        # Get input data; XCom takes precedence over a file when both are set.
        if self.input_xcom_key:
            # Naming the source task keeps the pull unambiguous and is
            # required by newer Airflow releases; task_ids=None matches the
            # previous key-only call.
            df = context['ti'].xcom_pull(
                task_ids=self.input_task_id,
                key=self.input_xcom_key,
            )
            if df is None:
                raise ValueError(f"No data found in XCom key: {self.input_xcom_key}")
            pipeline = Pipeline(specs, dataframe=df)
        elif self.input_file:
            # Read from file
            specs.dataset.source_path = self.input_file
            pipeline = Pipeline(specs)
        else:
            raise ValueError("Either input_xcom_key or input_file required")

        # Set output if specified
        if self.output_file:
            from pathlib import Path
            from llm_dataset_engine.core.specifications import OutputSpec, MergeStrategy, DataSourceType

            # The destination is always declared as CSV with REPLACE merge,
            # regardless of the file extension supplied.
            # NOTE(review): `specs.output` is assigned after the Pipeline was
            # built; this presumably relies on the pipeline holding a
            # reference to the same specs object - confirm.
            specs.output = OutputSpec(
                destination_type=DataSourceType.CSV,
                destination_path=Path(self.output_file),
                merge_strategy=MergeStrategy.REPLACE,
            )

        # Execute pipeline
        result = pipeline.execute()

        # Log metrics
        self.log.info(f"Processed {result.metrics.total_rows} rows")
        self.log.info(f"Cost: ${result.costs.total_cost}")
        self.log.info(f"Duration: {result.duration:.2f}s")

        # Push result to XCom
        if self.output_xcom_key:
            context['ti'].xcom_push(
                key=self.output_xcom_key,
                value=result.data
            )

        return result.data

148 

149else: 

150 # Airflow not installed 

class LLMTransformOperator:
    """Stand-in exported under the operator's name when Airflow is absent.

    The class exists so that importing this module never fails; any attempt
    to instantiate it raises ImportError with an installation hint.
    """

    def __init__(self, *args, **kwargs):
        """Reject construction; all arguments are accepted and ignored."""
        install_hint = (
            "Apache Airflow is required to use LLMTransformOperator. "
            "Install with: pip install apache-airflow"
        )
        raise ImportError(install_hint)

159 

160 

161__all__ = ["LLMTransformOperator"] 

162