Coverage for src / nanocli / config.py: 88%
106 statements
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-19 03:47 -0500
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-19 03:47 -0500
"""Config layer: Recursive config tree with OmegaConf.

This module provides configuration handling for NanoCLI:

- `ConfigError` - Exception raised for configuration problems
- `option()` - Dataclass field wrapper with help text
- `compile_config()` - Compile schema defaults, a base tree, and CLI overrides
- `load_yaml()` / `to_yaml()` / `save_yaml()` - YAML I/O
- `parse_overrides()` - Parse CLI overrides
- `get_flattened_config_options()` / `get_schema_structure()` - Schema
  introspection used to build CLI help
"""
11from dataclasses import MISSING, field, fields, is_dataclass
12from pathlib import Path
13from typing import Any, TypeVar
15from omegaconf import DictConfig, OmegaConf
# Type variable for the dataclass schema: compile_config(schema=type[T]) returns a T.
T = TypeVar("T")
class ConfigError(Exception):
    """Raised when NanoCLI configuration cannot be loaded or compiled.

    Typical causes: a config file that does not exist, an override string
    that is not in `key=value` form, or an override naming a key that the
    schema does not define.
    """
27def option(
28 default: Any = MISSING,
29 *,
30 help: str = "",
31 **kwargs: Any,
32) -> Any:
33 """Dataclass field wrapper with help text for CLI.
35 Use this instead of `field()` to add help text that appears in CLI help.
37 Args:
38 default: Default value for the field.
39 help: Help text shown in CLI.
40 **kwargs: Additional arguments passed to `dataclasses.field()`.
42 Returns:
43 A dataclass field with metadata.
45 Examples:
46 >>> from dataclasses import dataclass
47 >>> @dataclass
48 ... class Config:
49 ... epochs: int = option(100, help="Number of epochs")
50 ... lr: float = option(0.001, help="Learning rate")
51 >>> cfg = Config()
52 >>> cfg.epochs
53 100
54 """
55 metadata = kwargs.pop("metadata", {})
56 metadata["help"] = help
57 return field(default=default, metadata=metadata, **kwargs)
def load_yaml(path: str | Path) -> DictConfig:
    """Load a YAML file into a DictConfig.

    Args:
        path: Path to the YAML file.

    Returns:
        DictConfig containing the parsed YAML.
        NOTE(review): OmegaConf.load can also return a ListConfig for a
        top-level YAML list (hence the type: ignore below). Callers such as
        `key=@file.yml` overrides may rely on that, so it is not rejected here.

    Raises:
        ConfigError: If the path does not exist or is not a regular file
            (e.g. a directory).

    Examples:
        >>> import tempfile
        >>> from pathlib import Path
        >>> with tempfile.NamedTemporaryFile(suffix=".yml", delete=False, mode="w") as f:
        ...     _ = f.write("name: test\\ncount: 42")
        ...     path = f.name
        >>> cfg = load_yaml(path)
        >>> cfg.name
        'test'
        >>> Path(path).unlink()
    """
    path = Path(path)
    if not path.exists():
        raise ConfigError(f"Config file not found: {path}")
    if not path.is_file():
        # exists() is also true for directories; reading one would fail with
        # a raw IsADirectoryError instead of the documented ConfigError.
        raise ConfigError(f"Config path is not a file: {path}")
    return OmegaConf.load(path)  # type: ignore[return-value]
def parse_overrides(overrides: list[str]) -> DictConfig:
    """Parse CLI overrides into a config tree.

    Supports three types of overrides:

    - `key=value` - Scalar override
    - `key.path=value` - Nested override
    - `key=@file.yml` - Subtree replacement from file

    Overrides are applied in order, so later ones win: `a=1 a=2` yields
    `a: 2`, and `a.b=1 a=@x.yml` replaces the whole `a` subtree.

    Args:
        overrides: List of override strings.

    Returns:
        DictConfig with parsed overrides.

    Raises:
        ConfigError: If an override is malformed. This covers:
            - no '=' in the override;
            - an empty key or key segment (`=1`, `a..b=1`);
            - `@` with no file path;
            - a file that cannot be loaded;
            - setting a nested key under a key that already holds a
              non-mapping value (`a=1 a.b=2`).

    Examples:
        >>> cfg = parse_overrides(["name=test", "count=42"])
        >>> cfg.name
        'test'
        >>> cfg.count
        42
        >>> cfg = parse_overrides(["model.layers=24"])
        >>> cfg.model.layers
        24
    """
    result: dict[str, Any] = {}

    for override in overrides:
        if "=" not in override:
            raise ConfigError(f"Invalid override: '{override}'. Expected 'key=value' format.")

        key, value = override.split("=", 1)
        key = key.strip()
        value = value.strip()

        # Reject "", ".a", "a." and "a..b": they would create empty-string keys.
        path = key.split(".")
        if not all(path):
            raise ConfigError(
                f"Invalid override: '{override}'. Keys must be non-empty dotted names."
            )

        # Handle @file syntax for subtree replacement
        if value.startswith("@"):
            file_path = value[1:].strip()
            if not file_path:
                raise ConfigError(
                    f"Invalid override: '{override}'. Expected a file path after '@'."
                )
            parsed = OmegaConf.to_container(load_yaml(file_path))
        else:
            parsed = _parse_value(value)

        # Build nested dict from dot notation. A parent that already holds a
        # scalar or list cannot take child keys; report that as ConfigError.
        try:
            _set_nested(result, path, parsed)
        except (TypeError, AttributeError):
            raise ConfigError(
                f"Invalid override: '{override}'. A parent of '{key}' "
                "was already set to a non-mapping value."
            ) from None

    return OmegaConf.create(result)
139def _parse_value(value: str) -> Any:
140 """Parse a string value into Python type.
142 Args:
143 value: String to parse.
145 Returns:
146 Parsed Python value (bool, None, int, float, list, or str).
148 Examples:
149 >>> _parse_value("true")
150 True
151 >>> _parse_value("42")
152 42
153 >>> _parse_value("3.14")
154 3.14
155 >>> _parse_value("[1, 2, 3]")
156 [1, 2, 3]
157 """
158 # Boolean
159 if value.lower() == "true":
160 return True
161 if value.lower() == "false":
162 return False
164 # None
165 if value.lower() in ("none", "null"):
166 return None
168 # List
169 if value.startswith("[") and value.endswith("]"):
170 inner = value[1:-1].strip()
171 if not inner:
172 return []
173 return [_parse_value(v.strip()) for v in inner.split(",")]
175 # Integer
176 try:
177 return int(value)
178 except ValueError:
179 pass
181 # Float
182 try:
183 return float(value)
184 except ValueError:
185 pass
187 # Quoted string
188 if (value.startswith('"') and value.endswith('"')) or ( 188 ↛ 191line 188 didn't jump to line 191 because the condition on line 188 was never true
189 value.startswith("'") and value.endswith("'")
190 ):
191 return value[1:-1]
193 return value
196def _set_nested(d: dict[str, Any], keys: list[str], value: Any) -> None:
197 """Set a value in a nested dict using a list of keys.
199 Args:
200 d: Dictionary to modify.
201 keys: List of keys forming the path.
202 value: Value to set.
204 Examples:
205 >>> d = {}
206 >>> _set_nested(d, ["a", "b", "c"], 42)
207 >>> d
208 {'a': {'b': {'c': 42}}}
209 """
210 for key in keys[:-1]:
211 d = d.setdefault(key, {})
212 d[keys[-1]] = value
def _merge_or_raise(base: Any, other: Any, source: str) -> Any:
    """Merge `other` onto `base`, converting OmegaConf merge errors into ConfigError.

    `source` names where `other` came from (e.g. "overrides", "config file").
    It appears in the user-facing message for unknown keys.
    """
    import re

    try:
        return OmegaConf.merge(base, other)
    except Exception as e:  # OmegaConf raises several unrelated exception types
        error_msg = str(e)
        # Unknown keys are reported as, e.g.: Key 'typer' not in 'ModelConfig'
        # ([^']+ rather than \w+ so keys like 'learning-rate' are matched too.)
        match = re.search(r"Key '([^']+)' not in '([^']+)'", error_msg)
        if match:
            key, cls = match.groups()
            raise ConfigError(
                f"Invalid config key '{key}' in '{cls}'. Check for typos in your {source}."
            ) from None
        # Re-raise with friendlier message
        raise ConfigError(f"Config error: {error_msg}") from None


def compile_config(
    base: DictConfig | None = None,
    overrides: list[str] | None = None,
    schema: type[T] | None = None,
) -> DictConfig | T:
    """Compile a config from base + overrides.

    This is the core function: pure tree rewrite.
    Priority: schema defaults < base < overrides

    Args:
        base: Base config tree (from YAML).
        overrides: List of override strings (`key=value`, `key=@file`).
        schema: Optional dataclass for type validation.

    Returns:
        Compiled config. Typed if schema provided, else DictConfig.

    Raises:
        ConfigError: If an override string is malformed, or if merging the
            base tree or the overrides fails (e.g. an unknown key or a value
            of the wrong type). Errors raised by `OmegaConf.to_object`
            (such as a required field left unset) are not translated.

    Examples:
        >>> from dataclasses import dataclass
        >>> @dataclass
        ... class Config:
        ...     name: str = "default"
        ...     count: int = 1
        >>> cfg = compile_config(schema=Config)
        >>> cfg.name
        'default'
        >>> cfg = compile_config(schema=Config, overrides=["name=custom"])
        >>> cfg.name
        'custom'
    """
    # Build config: schema defaults -> base -> overrides
    if schema is not None:
        cfg = OmegaConf.structured(schema)
        if base is not None:
            # A typo in the YAML file is reported like a typo in an override.
            cfg = _merge_or_raise(cfg, base, "config file")
    else:
        cfg = base if base is not None else OmegaConf.create({})

    # Apply overrides (tree rewrite). parse_overrides raises its own
    # ConfigError before the merge is attempted.
    if overrides:
        cfg = _merge_or_raise(cfg, parse_overrides(overrides), "overrides")

    # Convert to typed object if schema provided
    if schema is not None:
        return OmegaConf.to_object(cfg)  # type: ignore[return-value]

    return cfg  # type: ignore[no-any-return]
def to_yaml(config: Any) -> str:
    """Render a config as a YAML string.

    Args:
        config: A DictConfig, a dataclass instance, or anything
            `OmegaConf.create` accepts (e.g. a plain dict).

    Returns:
        The YAML text produced by `OmegaConf.to_yaml`.

    Examples:
        >>> from dataclasses import dataclass
        >>> @dataclass
        ... class Config:
        ...     name: str = "test"
        >>> yaml_str = to_yaml(Config())
        >>> "name: test" in yaml_str
        True
    """
    if isinstance(config, DictConfig):
        # Already an OmegaConf tree; use it directly.
        node = config
    elif is_dataclass(config) and not isinstance(config, type):
        # Dataclass *instance* (a dataclass type falls through to create()).
        node = OmegaConf.structured(config)
    else:
        node = OmegaConf.create(config)
    return OmegaConf.to_yaml(node)
def save_yaml(config: Any, path: str | Path) -> None:
    """Save config to YAML file.

    The text is encoded as UTF-8 explicitly rather than with the platform's
    default encoding. That way non-ASCII values survive a save/load round
    trip on systems whose locale encoding is not UTF-8.

    Args:
        config: Config object to save (anything accepted by `to_yaml()`).
        path: Path to write the YAML file. Its parent directory must exist.

    Examples:
        >>> import tempfile
        >>> from dataclasses import dataclass
        >>> from pathlib import Path
        >>> @dataclass
        ... class Config:
        ...     name: str = "test"
        >>> with tempfile.TemporaryDirectory() as tmp:
        ...     out = Path(tmp) / "config.yml"
        ...     save_yaml(Config(), out)
        ...     content = out.read_text(encoding="utf-8")
        >>> "name: test" in content
        True
    """
    Path(path).write_text(to_yaml(config), encoding="utf-8")
332# --- Schema introspection for CLI help ---
def get_flattened_config_options(
    schema: type,
    prefix: str = "",
) -> list[tuple[str, str, Any, str]]:
    """Recursively flatten nested dataclasses into dotted paths.

    Fields whose type is itself a dataclass are expanded into their own
    fields, e.g. `model.layers`. String annotations (such as those produced
    by `from __future__ import annotations`) are resolved when possible, so
    nested dataclasses are still detected.

    Args:
        schema: Dataclass type to introspect.
        prefix: Current path prefix (for recursion).

    Returns:
        List of tuples: (dotted_name, type_name, default, help_text).
        `default` is the field's default, the result of calling its
        `default_factory`, or None for a required field. `help_text` comes
        from `option(help=...)` metadata, or "" if absent.

    Examples:
        >>> from dataclasses import dataclass
        >>> @dataclass
        ... class Config:
        ...     name: str = option("default", help="Name field")
        >>> opts = get_flattened_config_options(Config)
        >>> opts[0][0]
        'name'
    """
    from typing import get_type_hints

    if not is_dataclass(schema):
        return []

    # Resolve string annotations to real types. If a name cannot be resolved
    # (e.g. a class defined in a local scope), fall back to the raw Field.type.
    owner = schema if isinstance(schema, type) else type(schema)
    try:
        hints = get_type_hints(owner)
    except (NameError, TypeError, AttributeError, SyntaxError):
        hints = {}

    result: list[tuple[str, str, Any, str]] = []
    for f in fields(schema):
        full_name = f"{prefix}.{f.name}" if prefix else f.name
        field_type = hints.get(f.name, f.type)

        if is_dataclass(field_type):
            # Recurse into nested dataclass; its fields carry their own help.
            result.extend(get_flattened_config_options(field_type, prefix=full_name))  # type: ignore[arg-type]
            continue

        if f.default is not MISSING:
            default = f.default
        elif f.default_factory is not MISSING:
            # Same call the dataclass __init__ would make for this field.
            default = f.default_factory()
        else:
            default = None
        type_name = getattr(field_type, "__name__", str(field_type))
        result.append((full_name, type_name, default, f.metadata.get("help", "")))

    return result
def get_schema_structure(schema: type) -> tuple[dict[str, type], list[tuple[str, str, Any, str]]]:
    """Introspect a dataclass schema for CLI help.

    Args:
        schema: Dataclass type to introspect.

    Returns:
        A `(subcommands, config_options)` pair:
        - subcommands: always an empty dict, kept only so existing callers
          can still unpack two values.
        - config_options: the `(dotted_name, type_name, default, help)`
          tuples from `get_flattened_config_options(schema)`.
    """
    config_options = get_flattened_config_options(schema)
    no_subcommands: dict[str, type] = {}
    return no_subcommands, config_options