Coverage for llm_dataset_engine/integrations/airflow.py: 0%
57 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-15 18:04 +0200
1"""
2Airflow integration - Pre-built operators for Apache Airflow.
4Provides LLMTransformOperator for easy integration into Airflow DAGs.
5"""
7from typing import Any, Dict, Optional
try:
    from airflow.models import BaseOperator

    AIRFLOW_AVAILABLE = True
except ImportError:
    AIRFLOW_AVAILABLE = False
    BaseOperator = object  # Placeholder

# ``apply_defaults`` is imported separately: it was deprecated in Airflow 2.0
# and removed in Airflow 3.0, so its absence must not make an installed
# Airflow look missing (or replace the real BaseOperator with ``object``).
try:
    from airflow.utils.decorators import apply_defaults
except ImportError:

    def apply_defaults(func):
        """Return *func* unchanged (stand-in for Airflow's removed decorator)."""
        return func
18from llm_dataset_engine.api import Pipeline
19from llm_dataset_engine.config import ConfigLoader
if AIRFLOW_AVAILABLE:

    class LLMTransformOperator(BaseOperator):
        """
        Airflow operator for LLM dataset transformations.

        Integrates LLM Dataset Engine into Airflow DAGs with minimal boilerplate.
        It does the following:

        - loads a pipeline configuration;
        - applies the operator-level overrides;
        - runs the pipeline on data pulled from XCom or read from a file;
        - pushes the resulting data to XCom.

        Example:
            from llm_dataset_engine.integrations.airflow import LLMTransformOperator

            llm_task = LLMTransformOperator(
                task_id='llm_enrichment',
                config_path='configs/llm_config.yaml',
                input_xcom_key='raw_data',
                output_xcom_key='enriched_data',
                max_budget=10.0,
                dag=dag,
            )
        """

        # NOTE: ``apply_defaults`` is deprecated since Airflow 2.0, where
        # BaseOperator applies ``default_args`` itself; kept for older releases.
        @apply_defaults
        def __init__(
            self,
            config_path: str,
            input_xcom_key: Optional[str] = None,
            output_xcom_key: str = "llm_result",
            input_file: Optional[str] = None,
            output_file: Optional[str] = None,
            max_budget: Optional[float] = None,
            provider_override: Optional[str] = None,
            model_override: Optional[str] = None,
            *args,
            **kwargs,
        ):
            """
            Initialize LLM transform operator.

            Args:
                config_path: Path to the configuration file (loaded with
                    ``ConfigLoader.from_yaml``).
                input_xcom_key: XCom key to pull the input DataFrame from a
                    previous task. Takes precedence over ``input_file``.
                output_xcom_key: XCom key to push the result data under. A
                    falsy value disables this explicit push.
                input_file: Path to an input file (alternative to XCom). It is
                    written to ``specs.dataset.source_path``.
                output_file: Path of a CSV file to write results to, using the
                    REPLACE merge strategy.
                max_budget: Override for ``specs.processing.max_budget``.
                provider_override: Override LLM provider, converted with
                    ``LLMProvider(...)``.
                model_override: Override model name.
                *args: Airflow BaseOperator args.
                **kwargs: Airflow BaseOperator kwargs (e.g. ``task_id``, ``dag``).
            """
            super().__init__(*args, **kwargs)
            self.config_path = config_path
            self.input_xcom_key = input_xcom_key
            self.output_xcom_key = output_xcom_key
            self.input_file = input_file
            self.output_file = output_file
            self.max_budget = max_budget
            self.provider_override = provider_override
            self.model_override = model_override

        def execute(self, context: Dict[str, Any]) -> Any:
            """
            Execute LLM transformation.

            Args:
                context: Airflow task context. ``context['ti']`` is used for
                    the XCom pull and push.

            Returns:
                The pipeline's result data. It is also pushed to XCom under
                ``output_xcom_key`` when that key is set.

            Raises:
                ValueError: If neither ``input_xcom_key`` nor ``input_file`` is
                    set, or if the input XCom key holds no data.
            """
            # Validate the input source before any I/O so a misconfigured task
            # fails fast with a clear message.
            if not self.input_xcom_key and not self.input_file:
                raise ValueError("Either input_xcom_key or input_file required")
            if self.input_xcom_key and self.input_file:
                self.log.warning(
                    "Both input_xcom_key (%r) and input_file (%r) are set; "
                    "using XCom and ignoring input_file",
                    self.input_xcom_key,
                    self.input_file,
                )

            specs = ConfigLoader.from_yaml(self.config_path)
            # Apply every spec override, including the output destination,
            # before the Pipeline is constructed so none of them can be missed.
            self._apply_overrides(specs)
            pipeline = self._build_pipeline(specs, context)

            result = pipeline.execute()

            self.log.info("Processed %s rows", result.metrics.total_rows)
            self.log.info("Cost: $%s", result.costs.total_cost)
            self.log.info("Duration: %.2fs", result.duration)

            if self.output_xcom_key:
                context['ti'].xcom_push(
                    key=self.output_xcom_key,
                    value=result.data,
                )

            # NOTE(review): with do_xcom_push enabled (Airflow's default), the
            # return value is also stored under the "return_value" key. The
            # data may therefore be persisted to XCom twice.
            return result.data

        def _apply_overrides(self, specs: Any) -> None:
            """Apply the budget/provider/model/output overrides to *specs* in place."""
            if self.max_budget is not None:
                from decimal import Decimal

                # Going through str() yields e.g. Decimal('10.1') rather than
                # the exact binary expansion of the float.
                specs.processing.max_budget = Decimal(str(self.max_budget))

            if self.provider_override:
                from llm_dataset_engine.core.specifications import LLMProvider

                specs.llm.provider = LLMProvider(self.provider_override)

            if self.model_override:
                specs.llm.model = self.model_override

            if self.output_file:
                from pathlib import Path

                from llm_dataset_engine.core.specifications import (
                    DataSourceType,
                    MergeStrategy,
                    OutputSpec,
                )

                specs.output = OutputSpec(
                    destination_type=DataSourceType.CSV,
                    destination_path=Path(self.output_file),
                    merge_strategy=MergeStrategy.REPLACE,
                )

        def _build_pipeline(self, specs: Any, context: Dict[str, Any]) -> "Pipeline":
            """
            Create the Pipeline from XCom data, or from ``input_file`` (XCom wins).

            Assumes ``execute`` has already checked that at least one input
            source is configured.
            """
            if self.input_xcom_key:
                # No task_ids filter is passed; Airflow decides which XCom
                # entry matches the key.
                df = context['ti'].xcom_pull(key=self.input_xcom_key)
                if df is None:
                    raise ValueError(f"No data found in XCom key: {self.input_xcom_key}")
                return Pipeline(specs, dataframe=df)

            specs.dataset.source_path = self.input_file
            return Pipeline(specs)

else:
    # Airflow not installed: export a stand-in with the same name so importing
    # this module still works, while constructing it fails with a clear message.
    class LLMTransformOperator:
        """Placeholder when Airflow not installed."""

        def __init__(self, *args, **kwargs):
            raise ImportError(
                "Apache Airflow is required to use LLMTransformOperator. "
                "Install with: pip install apache-airflow"
            )
# Public API of this module: only the operator (or its placeholder) is exported.
__all__ = [
    "LLMTransformOperator",
]