Coverage for llm_dataset_engine/stages/result_writer_stage.py: 25%
56 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-15 18:04 +0200
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-15 18:04 +0200
1"""Result writing stage for persisting output."""
3from decimal import Decimal
4from pathlib import Path
5from typing import Any
7import pandas as pd
9from llm_dataset_engine.adapters.data_io import create_data_writer
10from llm_dataset_engine.core.models import (
11 CostEstimate,
12 ValidationResult,
13 WriteConfirmation,
14)
15from llm_dataset_engine.core.specifications import (
16 DataSourceType,
17 MergeStrategy,
18 OutputSpec,
19)
20from llm_dataset_engine.stages.pipeline_stage import PipelineStage
class ResultWriterStage(
    PipelineStage[
        tuple[pd.DataFrame, pd.DataFrame, OutputSpec], WriteConfirmation
    ]
):
    """
    Write results to destination with merge support.

    Responsibilities:
    - Merge results with original data
    - Write to configured destination
    - Support atomic writes
    - Return confirmation

    Stage input is an ``(original_df, results_df, output_spec)`` tuple;
    stage output is a ``WriteConfirmation``.
    """

    def __init__(self):
        """Initialize result writer stage."""
        super().__init__("ResultWriter")

    def process(
        self,
        input_data: tuple[pd.DataFrame, pd.DataFrame, OutputSpec],
        context: Any,
    ) -> WriteConfirmation:
        """
        Merge results into the original data and write the outcome.

        Args:
            input_data: ``(original_df, results_df, output_spec)``.
            context: Pipeline context; not used by this stage.

        Returns:
            The confirmation returned by the data writer, or an in-memory
            confirmation (``path="<in-memory>"``) when
            ``output_spec.destination_path`` is not set.

        Raises:
            ValueError: Propagated from ``_merge_results`` (APPEND column
                clash, UPDATE rows missing from the original, or an
                unknown merge strategy).
        """
        original_df, results_df, output_spec = input_data

        merged_df = self._merge_results(
            original_df, results_df, output_spec.merge_strategy
        )

        if not output_spec.destination_path:
            # No destination specified: nothing is written, only the row
            # count of the merged frame is reported.
            return WriteConfirmation(
                path="<in-memory>",
                rows_written=len(merged_df),
                success=True,
            )

        writer = create_data_writer(output_spec.destination_type)
        if output_spec.atomic_write:
            confirmation = writer.atomic_write(
                merged_df, output_spec.destination_path
            )
        else:
            confirmation = writer.write(
                merged_df, output_spec.destination_path
            )

        self.logger.info(
            f"Wrote {confirmation.rows_written} rows to "
            f"{confirmation.path}"
        )
        return confirmation

    def _merge_results(
        self,
        original: pd.DataFrame,
        results: pd.DataFrame,
        strategy: MergeStrategy,
    ) -> pd.DataFrame:
        """
        Merge ``results`` into ``original`` according to ``strategy``.

        - REPLACE: copy ``original`` and overwrite/add every results column
          (values are aligned on the index).
        - APPEND: concatenate the columns side by side; any results column
          that already exists in ``original`` is an error.
        - UPDATE: copy ``original``; existing columns receive only the
          non-null result values, new columns are added as-is.

        Raises:
            ValueError: On an APPEND column clash, on UPDATE results that
                contain rows absent from ``original``, or on an unknown
                strategy.
        """
        if strategy == MergeStrategy.REPLACE:
            # Replace existing columns or add new ones
            merged = original.copy()
            for col in results.columns:
                merged[col] = results[col]
            return merged

        if strategy == MergeStrategy.APPEND:
            # Add as new columns (error if exists)
            for col in results.columns:
                if col in original.columns:
                    raise ValueError(f"Column {col} already exists")
            return pd.concat([original, results], axis=1)

        if strategy == MergeStrategy.UPDATE:
            return self._update_non_null(original, results)

        raise ValueError(f"Unknown merge strategy: {strategy}")

    @staticmethod
    def _update_non_null(
        original: pd.DataFrame,
        results: pd.DataFrame,
    ) -> pd.DataFrame:
        """
        Overlay the non-null values of ``results`` onto a copy of ``original``.

        ``results`` may cover only a subset of the original rows; each
        column is aligned onto the original index before masking, so rows
        without a result keep their original value.

        Raises:
            ValueError: If an existing column is updated from results rows
                whose index labels are not present in ``original``.
        """
        merged = original.copy()
        for col in results.columns:
            if col not in merged.columns:
                merged[col] = results[col]
                continue

            updates = results[col]
            if not updates.index.equals(merged.index):
                # A boolean mask built on a different index cannot be used
                # with .loc on ``merged`` (pandas raises IndexingError), so
                # align the partial results onto the original rows first.
                missing = updates.index.difference(merged.index)
                if len(missing) > 0:
                    raise ValueError(
                        f"Column {col}: results contain {len(missing)} "
                        "row(s) not present in the original data"
                    )
                updates = updates.reindex(merged.index)

            # Update only non-null values
            mask = updates.notna()
            if mask.any():
                merged.loc[mask, col] = updates[mask]
        return merged

    def validate_input(
        self,
        input_data: tuple[pd.DataFrame, pd.DataFrame, OutputSpec],
    ) -> ValidationResult:
        """
        Validate input data and output specification.

        An empty original frame is a warning, an empty results frame is an
        error, and a missing destination directory is a warning.
        """
        result = ValidationResult(is_valid=True)

        original_df, results_df, output_spec = input_data

        if original_df.empty:
            result.add_warning("Original DataFrame is empty")

        if results_df.empty:
            result.add_error("Results DataFrame is empty")

        # Check destination path if specified
        if output_spec.destination_path:
            # Path() accepts both str and Path destinations, so ``.parent``
            # is always available.
            dest_dir = Path(output_spec.destination_path).parent
            if not dest_dir.exists():
                result.add_warning(
                    f"Destination directory does not exist: {dest_dir}"
                )

        return result

    def estimate_cost(
        self,
        input_data: tuple[pd.DataFrame, pd.DataFrame, OutputSpec],
    ) -> CostEstimate:
        """Result writing has no LLM cost; ``rows`` is the results row count."""
        return CostEstimate(
            total_cost=Decimal("0.0"),
            total_tokens=0,
            input_tokens=0,
            output_tokens=0,
            rows=len(input_data[1]),
        )