Coverage for llm_dataset_engine/stages/data_loader_stage.py: 34%

35 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-10-15 18:04 +0200

1"""Data loading stage for reading tabular data.""" 

2 

3from decimal import Decimal 

4from typing import Any 

5 

6import pandas as pd 

7 

8from llm_dataset_engine.adapters.data_io import create_data_reader 

9from llm_dataset_engine.core.models import CostEstimate, ValidationResult 

10from llm_dataset_engine.core.specifications import DatasetSpec 

11from llm_dataset_engine.stages.pipeline_stage import PipelineStage 

12 

13 

class DataLoaderStage(PipelineStage[DatasetSpec, pd.DataFrame]):
    """
    Load data from source and validate schema.

    Responsibilities:
    - Read data from configured source
    - Validate required columns exist
    - Apply any filters
    - Update context with row count
    """

    def __init__(self, dataframe: pd.DataFrame | None = None):
        """
        Initialize data loader stage.

        Args:
            dataframe: Optional pre-loaded dataframe (for DataFrame source).
                It is forwarded to ``create_data_reader`` on every
                ``process`` call.
        """
        super().__init__("DataLoader")
        self.dataframe = dataframe

    def process(self, spec: DatasetSpec, context: Any) -> pd.DataFrame:
        """
        Load data from the configured source, validate it, and apply filters.

        Args:
            spec: Dataset specification. Its source settings (type, path,
                delimiter, encoding, sheet name) are passed to
                ``create_data_reader``. ``input_columns`` must all be present
                in the loaded data. ``filters`` is an optional mapping of
                column -> value, and only rows where the column equals the
                value are kept.
            context: Pipeline context object. Its ``total_rows`` attribute is
                set to the number of rows remaining after filtering.

        Returns:
            The loaded DataFrame after any equality filters have been applied.
            The original index is preserved (not reset).

        Raises:
            ValueError: If any column in ``spec.input_columns`` is missing
                from the loaded data.
        """
        # Create the reader appropriate for this source type.
        reader = create_data_reader(
            source_type=spec.source_type,
            source_path=spec.source_path,
            dataframe=self.dataframe,
            delimiter=spec.delimiter,
            encoding=spec.encoding,
            sheet_name=spec.sheet_name,
        )

        df = reader.read()

        # Every required input column must be present before processing.
        missing_cols = set(spec.input_columns) - set(df.columns)
        if missing_cols:
            # Sort so the message is deterministic: set iteration order of
            # strings varies between interpreter runs (hash randomization).
            raise ValueError(
                f"Missing columns: {sorted(missing_cols, key=str)}"
            )

        # Apply equality filters, if any were specified.
        if spec.filters:
            for column, value in spec.filters.items():
                if column not in df.columns:
                    # Unknown filter columns have always been skipped. Keep
                    # that behavior but surface it, since it usually means
                    # the filter config is wrong.
                    self.logger.warning(
                        f"Filter column {column!r} not found in data; skipping"
                    )
                    continue
                df = df[df[column] == value]

        row_count = len(df)
        context.total_rows = row_count

        self.logger.info(f"Loaded {row_count} rows from {spec.source_type}")

        return df

    def validate_input(self, spec: DatasetSpec) -> ValidationResult:
        """
        Validate a dataset specification before loading.

        Collects every problem found rather than stopping at the first one.

        Args:
            spec: Dataset specification to check.

        Returns:
            A ``ValidationResult`` that starts valid, with one error added for
            each of the following:
            - ``source_path`` is set but does not exist
            - ``input_columns`` is empty
            - ``output_columns`` is empty
        """
        from pathlib import Path

        result = ValidationResult(is_valid=True)

        # Check that the file exists for file-based sources. Wrapping in
        # Path() accepts both str and Path values; str has no .exists().
        if spec.source_path and not Path(spec.source_path).exists():
            result.add_error(f"Source file not found: {spec.source_path}")

        # Check that input columns are specified.
        if not spec.input_columns:
            result.add_error("No input columns specified")

        # Check that output columns are specified.
        if not spec.output_columns:
            result.add_error("No output columns specified")

        return result

    def estimate_cost(self, spec: DatasetSpec) -> CostEstimate:
        """
        Estimate the LLM cost of this stage.

        Data loading makes no LLM calls, so every field of the estimate is
        zero regardless of ``spec``.

        Args:
            spec: Dataset specification (unused).

        Returns:
            A zero-valued ``CostEstimate``.
        """
        return CostEstimate(
            total_cost=Decimal("0.0"),
            total_tokens=0,
            input_tokens=0,
            output_tokens=0,
            rows=0,
        )

95