Coverage for llm_dataset_engine/core/specifications.py: 89%

113 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-10-15 18:04 +0200

1""" 

2Core specification models for pipeline configuration. 

3 

4These Pydantic models define the configuration contracts for all pipeline 

5components, following the principle of separation between configuration 

6(what to do) and execution (how to do it). 

7""" 

8 

9from decimal import Decimal 

10from enum import Enum 

11from pathlib import Path 

12from typing import Any, Dict, List, Optional, Union 

13 

14from pydantic import BaseModel, ConfigDict, Field, field_validator 

15 

16 

class DataSourceType(str, Enum):
    """Kinds of data source (or destination) a specification can name.

    Members mix in ``str``, so each one compares equal to its raw string
    value and can be looked up from that value, e.g.
    ``DataSourceType("csv")``.
    """

    CSV = "csv"
    EXCEL = "excel"
    PARQUET = "parquet"
    DATAFRAME = "dataframe"

24 

25 

class LLMProvider(str, Enum):
    """LLM back-ends that a specification can select.

    Members mix in ``str``, so each one compares equal to its raw string
    value and can be looked up from that value, e.g.
    ``LLMProvider("openai")``.
    """

    OPENAI = "openai"
    AZURE_OPENAI = "azure_openai"
    ANTHROPIC = "anthropic"
    GROQ = "groq"

33 

34 

class ErrorPolicy(str, Enum):
    """Named policies for reacting to a processing failure.

    Members mix in ``str``, so each one compares equal to its raw string
    value and can be looked up from that value, e.g.
    ``ErrorPolicy("skip")``.
    """

    RETRY = "retry"
    SKIP = "skip"
    FAIL = "fail"
    USE_DEFAULT = "use_default"

42 

43 

class MergeStrategy(str, Enum):
    """Named strategies for merging results into an output.

    Members mix in ``str``, so each one compares equal to its raw string
    value and can be looked up from that value, e.g.
    ``MergeStrategy("append")``.
    """

    REPLACE = "replace"
    APPEND = "append"
    UPDATE = "update"

50 

51 

class DatasetSpec(BaseModel):
    """Configuration describing where input data comes from.

    ``input_columns`` and ``output_columns`` are both required, must each
    contain at least one name, and may not share any name. ``source_path``
    is optional; a string value is converted to a ``Path``.
    """

    source_type: DataSourceType
    source_path: Optional[Union[str, Path]] = None
    input_columns: List[str] = Field(
        ...,
        min_length=1,
        description="Columns to use as input",
    )
    output_columns: List[str] = Field(
        ...,
        min_length=1,
        description="Columns to store results",
    )
    filters: Optional[Dict[str, Any]] = Field(
        default=None,
        description="Optional data filters",
    )
    sheet_name: Optional[Union[str, int]] = Field(
        default=0,
        description="Sheet name for Excel files",
    )
    delimiter: str = Field(default=",", description="CSV delimiter")
    encoding: str = Field(default="utf-8", description="File encoding")

    @field_validator("source_path")
    @classmethod
    def validate_source_path(
        cls, v: Optional[Union[str, Path]]
    ) -> Optional[Path]:
        """Turn a string path into a ``Path``; return anything else as-is."""
        if isinstance(v, str):
            return Path(v)
        # Covers both ``None`` and values that are already ``Path`` objects.
        return v

    @field_validator("output_columns")
    @classmethod
    def validate_no_overlap(cls, v: List[str], info: Any) -> List[str]:
        """Reject output columns that reuse an input column name.

        Raises:
            ValueError: If any name appears in both column lists.
        """
        # ``input_columns`` is absent from ``info.data`` when it failed its
        # own validation; there is nothing to compare against in that case.
        if "input_columns" not in info.data:
            return v

        shared = set(info.data["input_columns"]) & set(v)
        if shared:
            raise ValueError(f"Output columns overlap with input: {shared}")
        return v

95 

96 

class PromptSpec(BaseModel):
    """Specification for prompt template configuration.

    ``template`` is required, must be non-empty, and must contain at least
    one ``{var}`` style placeholder. The remaining fields are optional and
    default to ``None``.
    """

    template: str = Field(..., min_length=1, description="Prompt template")
    system_message: Optional[str] = Field(
        default=None, description="System message for LLM"
    )
    few_shot_examples: Optional[List[Dict[str, str]]] = Field(
        default=None, description="Few-shot learning examples"
    )
    template_variables: Optional[List[str]] = Field(
        default=None, description="Expected template variables"
    )

    @field_validator("template")
    @classmethod
    def validate_template(cls, v: str) -> str:
        """Validate that the template has at least one ``{var}`` placeholder.

        A placeholder is an opening brace, at least one non-whitespace
        character that is not a brace, and a closing brace. Merely
        containing both brace characters somewhere (for example ``"}{"``
        or ``"{}"``) is not enough.

        Args:
            v: The template text to check.

        Returns:
            The template, unchanged.

        Raises:
            ValueError: If no placeholder is found.
        """
        # Local import so this block does not depend on the module's
        # top-of-file import list.
        import re

        if re.search(r"\{\s*[^\s{}][^{}]*\}", v) is None:
            raise ValueError(
                "Template must contain at least one variable in {var} format"
            )
        return v

120 

121 

class LLMSpec(BaseModel):
    """Specification for LLM provider configuration.

    ``provider`` and ``model`` are required. When ``provider`` is
    ``LLMProvider.AZURE_OPENAI``, both ``azure_endpoint`` and
    ``azure_deployment`` must be supplied; leaving either out (or passing
    ``None``) is a validation error.
    """

    provider: LLMProvider
    model: str = Field(..., min_length=1, description="Model identifier")
    # repr=False keeps the secret out of repr()/str() output, and therefore
    # out of logs and tracebacks that print the model.
    api_key: Optional[str] = Field(
        default=None, repr=False, description="API key (or from env)"
    )
    temperature: float = Field(
        default=0.0, ge=0.0, le=2.0, description="Sampling temperature"
    )
    max_tokens: Optional[int] = Field(
        default=None, gt=0, description="Max output tokens"
    )
    top_p: float = Field(
        default=1.0, ge=0.0, le=1.0, description="Nucleus sampling"
    )

    # Azure-specific
    # validate_default=True is required on the next two fields: field
    # validators are skipped for default values otherwise, so an omitted
    # endpoint/deployment would never reach validate_azure_config.
    azure_endpoint: Optional[str] = Field(
        default=None,
        validate_default=True,
        description="Azure OpenAI endpoint",
    )
    azure_deployment: Optional[str] = Field(
        default=None,
        validate_default=True,
        description="Azure deployment name",
    )
    api_version: Optional[str] = Field(
        default="2024-02-15-preview", description="Azure API version"
    )

    # Cost tracking
    input_cost_per_1k_tokens: Optional[Decimal] = Field(
        default=None, description="Input token cost"
    )
    output_cost_per_1k_tokens: Optional[Decimal] = Field(
        default=None, description="Output token cost"
    )

    @field_validator("azure_endpoint", "azure_deployment")
    @classmethod
    def validate_azure_config(
        cls, v: Optional[str], info: Any
    ) -> Optional[str]:
        """Require the Azure fields when the Azure OpenAI provider is chosen.

        Args:
            v: Value of the field being validated (may be the default).
            info: Validation context; ``info.data`` holds the fields
                validated so far and ``info.field_name`` the current field.

        Returns:
            The value, unchanged.

        Raises:
            ValueError: If the provider is ``LLMProvider.AZURE_OPENAI`` and
                the value is ``None``.
        """
        # ``provider`` is declared before the Azure fields, so it is already
        # in ``info.data`` here unless it failed its own validation.
        if info.data.get("provider") == LLMProvider.AZURE_OPENAI:
            if v is None:
                field_name = info.field_name
                raise ValueError(
                    f"{field_name} required for Azure OpenAI provider"
                )
        return v

172 

173 

class ProcessingSpec(BaseModel):
    """Tunable parameters controlling how processing is carried out.

    Every field has a default, so ``ProcessingSpec()`` is valid. Numeric
    fields carry the bounds declared on their ``Field`` definitions below.
    """

    batch_size: int = Field(
        default=100,
        gt=0,
        le=1000,
        description="Rows per batch",
    )
    concurrency: int = Field(
        default=5,
        gt=0,
        le=20,
        description="Parallel requests",
    )
    checkpoint_interval: int = Field(
        default=500,
        gt=0,
        description="Checkpoint frequency",
    )
    max_retries: int = Field(
        default=3,
        ge=0,
        description="Max retry attempts",
    )
    retry_delay: float = Field(
        default=1.0,
        ge=0.0,
        description="Initial retry delay (seconds)",
    )
    error_policy: ErrorPolicy = Field(
        default=ErrorPolicy.SKIP,
        description="Error handling policy",
    )
    rate_limit_rpm: Optional[int] = Field(
        default=None,
        gt=0,
        description="Requests per minute limit",
    )
    max_budget: Optional[Decimal] = Field(
        default=None,
        gt=0,
        description="Maximum budget in USD",
    )
    checkpoint_dir: Path = Field(
        default=Path(".checkpoints"),
        description="Checkpoint directory",
    )

    @field_validator("checkpoint_dir")
    @classmethod
    def validate_checkpoint_dir(cls, v: Union[str, Path]) -> Path:
        """Turn a string path into a ``Path``; return anything else as-is."""
        if isinstance(v, str):
            return Path(v)
        return v

210 

211 

class OutputSpec(BaseModel):
    """Configuration describing where and how results are written.

    Only ``destination_type`` is required. ``merge_strategy`` defaults to
    ``MergeStrategy.REPLACE`` and ``atomic_write`` defaults to ``True``.
    """

    destination_type: DataSourceType
    destination_path: Optional[Path] = None
    merge_strategy: MergeStrategy = Field(
        default=MergeStrategy.REPLACE,
        description="Output merge strategy",
    )
    atomic_write: bool = Field(
        default=True,
        description="Use atomic writes",
    )

    @field_validator("destination_path")
    @classmethod
    def validate_destination_path(
        cls, v: Optional[Union[str, Path]]
    ) -> Optional[Path]:
        """Turn a string path into a ``Path``; return anything else as-is."""
        if isinstance(v, str):
            return Path(v)
        # Covers both ``None`` and values that are already ``Path`` objects.
        return v

233 

234 

class PipelineSpecifications(BaseModel):
    """Bundle of every specification that makes up one pipeline.

    ``dataset``, ``prompt`` and ``llm`` are required. ``processing`` falls
    back to a default-constructed ``ProcessingSpec``, ``output`` may be
    omitted, and ``metadata`` starts as an empty dict. Assignments to
    attributes are re-validated (``validate_assignment=True``).
    """

    model_config = ConfigDict(
        validate_assignment=True,
        arbitrary_types_allowed=True,
    )

    dataset: DatasetSpec
    prompt: PromptSpec
    llm: LLMSpec
    processing: ProcessingSpec = Field(default_factory=ProcessingSpec)
    output: Optional[OutputSpec] = None
    metadata: Dict[str, Any] = Field(
        default_factory=dict,
        description="Custom metadata",
    )

251