Coverage for llm_dataset_engine/core/specifications.py: 89%
113 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-15 18:04 +0200
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-15 18:04 +0200
1"""
2Core specification models for pipeline configuration.
4These Pydantic models define the configuration contracts for all pipeline
5components, following the principle of separation between configuration
6(what to do) and execution (how to do it).
7"""
9from decimal import Decimal
10from enum import Enum
11from pathlib import Path
12from typing import Any, Dict, List, Optional, Union
14from pydantic import BaseModel, ConfigDict, Field, field_validator
class DataSourceType(str, Enum):
    """Kinds of data source/destination a spec can refer to.

    The ``str`` mix-in makes every member compare equal to its raw string
    value, e.g. ``DataSourceType.CSV == "csv"``.
    """

    CSV = "csv"
    EXCEL = "excel"
    PARQUET = "parquet"
    DATAFRAME = "dataframe"
class LLMProvider(str, Enum):
    """Identifiers of the LLM providers a spec may select.

    Members are also ``str`` instances, so they compare equal to their raw
    string value, e.g. ``LLMProvider.OPENAI == "openai"``.
    """

    OPENAI = "openai"
    AZURE_OPENAI = "azure_openai"
    ANTHROPIC = "anthropic"
    GROQ = "groq"
class ErrorPolicy(str, Enum):
    """Named policies for reacting to a processing failure.

    Members are also ``str`` instances, so they compare equal to their raw
    string value, e.g. ``ErrorPolicy.SKIP == "skip"``.
    """

    RETRY = "retry"
    SKIP = "skip"
    FAIL = "fail"
    USE_DEFAULT = "use_default"
class MergeStrategy(str, Enum):
    """Named strategies for merging produced output.

    Members are also ``str`` instances, so they compare equal to their raw
    string value, e.g. ``MergeStrategy.REPLACE == "replace"``.
    """

    REPLACE = "replace"
    APPEND = "append"
    UPDATE = "update"
class DatasetSpec(BaseModel):
    """Configuration of the data source and the columns taking part.

    ``input_columns`` and ``output_columns`` must each hold at least one
    name, and no name may appear in both lists.
    """

    source_type: DataSourceType
    source_path: Optional[Union[str, Path]] = None
    input_columns: List[str] = Field(
        ...,
        min_length=1,
        description="Columns to use as input",
    )
    output_columns: List[str] = Field(
        ...,
        min_length=1,
        description="Columns to store results",
    )
    filters: Optional[Dict[str, Any]] = Field(
        default=None,
        description="Optional data filters",
    )
    sheet_name: Optional[Union[str, int]] = Field(
        default=0,
        description="Sheet name for Excel files",
    )
    delimiter: str = Field(default=",", description="CSV delimiter")
    encoding: str = Field(default="utf-8", description="File encoding")

    @field_validator("source_path")
    @classmethod
    def validate_source_path(
        cls, v: Optional[Union[str, Path]]
    ) -> Optional[Path]:
        """Turn a string ``source_path`` into a ``Path``.

        ``None`` and values that are not strings are returned unchanged.
        """
        if isinstance(v, str):
            return Path(v)
        return v

    @field_validator("output_columns")
    @classmethod
    def validate_no_overlap(cls, v: List[str], info: Any) -> List[str]:
        """Reject output columns that also appear among the input columns.

        Raises:
            ValueError: If at least one name is in both lists.
        """
        # Without an ``input_columns`` entry in the already-validated data
        # there is nothing to compare against.
        if "input_columns" not in info.data:
            return v

        shared = set(info.data["input_columns"]) & set(v)
        if shared:
            raise ValueError(f"Output columns overlap with input: {shared}")
        return v
class PromptSpec(BaseModel):
    """Specification for prompt template configuration.

    ``template`` is the only required field; it must be non-empty and
    contain at least one ``{...}`` placeholder.
    """

    template: str = Field(..., min_length=1, description="Prompt template")
    system_message: Optional[str] = Field(
        default=None, description="System message for LLM"
    )
    few_shot_examples: Optional[List[Dict[str, str]]] = Field(
        default=None, description="Few-shot learning examples"
    )
    template_variables: Optional[List[str]] = Field(
        default=None, description="Expected template variables"
    )

    @field_validator("template")
    @classmethod
    def validate_template(cls, v: str) -> str:
        """Validate that the template has at least one ``{...}`` placeholder.

        The check is structural only: an opening brace must be followed,
        somewhere later in the string, by a closing brace. It does not
        verify that the text between the braces is a valid variable name.

        Raises:
            ValueError: If no opening brace is followed by a closing brace.
        """
        # Look for "}" only *after* the first "{": merely testing that both
        # characters occur would accept templates such as "} text {" that
        # contain no placeholder at all.
        open_idx = v.find("{")
        if open_idx == -1 or v.find("}", open_idx + 1) == -1:
            raise ValueError(
                "Template must contain at least one variable in {var} format"
            )
        return v
class LLMSpec(BaseModel):
    """Specification for LLM provider configuration.

    ``provider`` and ``model`` are required. When ``provider`` is
    ``LLMProvider.AZURE_OPENAI``, ``azure_endpoint`` and
    ``azure_deployment`` must also be supplied; validation fails otherwise.
    """

    provider: LLMProvider
    model: str = Field(..., min_length=1, description="Model identifier")
    api_key: Optional[str] = Field(
        default=None, description="API key (or from env)"
    )
    temperature: float = Field(
        default=0.0, ge=0.0, le=2.0, description="Sampling temperature"
    )
    max_tokens: Optional[int] = Field(
        default=None, gt=0, description="Max output tokens"
    )
    top_p: float = Field(
        default=1.0, ge=0.0, le=1.0, description="Nucleus sampling"
    )

    # Azure-specific
    # validate_default=True is essential on the two fields below: pydantic
    # does not run field validators for values that come from the default,
    # so without it an *omitted* endpoint/deployment would never reach
    # validate_azure_config and the "required for Azure" check would be
    # silently skipped.
    azure_endpoint: Optional[str] = Field(
        default=None,
        validate_default=True,
        description="Azure OpenAI endpoint",
    )
    azure_deployment: Optional[str] = Field(
        default=None,
        validate_default=True,
        description="Azure deployment name",
    )
    api_version: Optional[str] = Field(
        default="2024-02-15-preview", description="Azure API version"
    )

    # Cost tracking
    input_cost_per_1k_tokens: Optional[Decimal] = Field(
        default=None, description="Input token cost"
    )
    output_cost_per_1k_tokens: Optional[Decimal] = Field(
        default=None, description="Output token cost"
    )

    @field_validator("azure_endpoint", "azure_deployment")
    @classmethod
    def validate_azure_config(
        cls, v: Optional[str], info: Any
    ) -> Optional[str]:
        """Require the Azure fields when the provider is Azure OpenAI.

        ``provider`` is declared before the Azure fields, so its validated
        value is available in ``info.data`` here. If ``provider`` is missing
        from ``info.data``, the check is skipped.

        Raises:
            ValueError: If the provider is ``AZURE_OPENAI`` and the field
                being validated is ``None``.
        """
        if info.data.get("provider") == LLMProvider.AZURE_OPENAI:
            if v is None:
                field_name = info.field_name
                raise ValueError(
                    f"{field_name} required for Azure OpenAI provider"
                )
        return v
class ProcessingSpec(BaseModel):
    """Tunable processing parameters; every field has a default.

    Numeric bounds are enforced through the ``Field`` constraints below.
    """

    batch_size: int = Field(
        default=100,
        gt=0,
        le=1000,
        description="Rows per batch",
    )
    concurrency: int = Field(
        default=5,
        gt=0,
        le=20,
        description="Parallel requests",
    )
    checkpoint_interval: int = Field(
        default=500,
        gt=0,
        description="Checkpoint frequency",
    )
    max_retries: int = Field(
        default=3,
        ge=0,
        description="Max retry attempts",
    )
    retry_delay: float = Field(
        default=1.0,
        ge=0.0,
        description="Initial retry delay (seconds)",
    )
    error_policy: ErrorPolicy = Field(
        default=ErrorPolicy.SKIP,
        description="Error handling policy",
    )
    rate_limit_rpm: Optional[int] = Field(
        default=None,
        gt=0,
        description="Requests per minute limit",
    )
    max_budget: Optional[Decimal] = Field(
        default=None,
        gt=0,
        description="Maximum budget in USD",
    )
    checkpoint_dir: Path = Field(
        default=Path(".checkpoints"),
        description="Checkpoint directory",
    )

    @field_validator("checkpoint_dir")
    @classmethod
    def validate_checkpoint_dir(cls, v: Union[str, Path]) -> Path:
        """Return ``checkpoint_dir`` as a ``Path``, wrapping string input."""
        if isinstance(v, str):
            return Path(v)
        return v
class OutputSpec(BaseModel):
    """Configuration of where and how results are written.

    Only ``destination_type`` is required; the other fields have defaults.
    """

    destination_type: DataSourceType
    destination_path: Optional[Path] = None
    merge_strategy: MergeStrategy = Field(
        default=MergeStrategy.REPLACE,
        description="Output merge strategy",
    )
    atomic_write: bool = Field(default=True, description="Use atomic writes")

    @field_validator("destination_path")
    @classmethod
    def validate_destination_path(
        cls, v: Optional[Union[str, Path]]
    ) -> Optional[Path]:
        """Turn a string ``destination_path`` into a ``Path``.

        ``None`` and values that are not strings are returned unchanged.
        """
        if isinstance(v, str):
            return Path(v)
        return v
class PipelineSpecifications(BaseModel):
    """Bundle of every specification a pipeline needs.

    ``dataset``, ``prompt`` and ``llm`` are required. ``processing`` falls
    back to a default-constructed ``ProcessingSpec``, ``output`` to ``None``
    and ``metadata`` to an empty dict.
    """

    # validate_assignment=True makes pydantic validate values assigned to
    # fields after construction, not only those passed to the constructor.
    model_config = ConfigDict(arbitrary_types_allowed=True, validate_assignment=True)

    dataset: DatasetSpec
    prompt: PromptSpec
    llm: LLMSpec
    processing: ProcessingSpec = Field(default_factory=ProcessingSpec)
    output: Optional[OutputSpec] = None
    metadata: Dict[str, Any] = Field(
        default_factory=dict,
        description="Custom metadata",
    )