Coverage for src / nanocli / config.py: 88%

106 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2025-12-19 03:47 -0500

1"""Config layer: Recursive config tree with OmegaConf. 

2 

3This module provides configuration handling for NanoCLI: 

4 

5- `option()` - Dataclass field wrapper with help text 

6- `compile_config()` - Pure function to compile configs 

7- `load_yaml()` / `to_yaml()` - YAML I/O 

8- `parse_overrides()` - Parse CLI overrides 

9""" 

10 

11from dataclasses import MISSING, field, fields, is_dataclass 

12from pathlib import Path 

13from typing import Any, TypeVar 

14 

15from omegaconf import DictConfig, OmegaConf 

16 

17T = TypeVar("T") 

18 

19 

class ConfigError(Exception):
    """Error type for configuration problems.

    Used by this module to report a missing config file, a malformed
    override string, or a config key that could not be merged.
    """

25 

26 

def option(
    default: Any = MISSING,
    *,
    help: str = "",
    **kwargs: Any,
) -> Any:
    """Dataclass field wrapper with help text for CLI.

    Use this instead of `field()` to add help text that appears in CLI help.
    The help text is stored under the ``"help"`` key of the field metadata.

    Args:
        default: Default value for the field.
        help: Help text shown in CLI.
        **kwargs: Additional arguments passed to `dataclasses.field()`.
            A ``metadata`` mapping, if given, is copied and never modified.

    Returns:
        A dataclass field with metadata.

    Examples:
        >>> from dataclasses import dataclass
        >>> @dataclass
        ... class Config:
        ...     epochs: int = option(100, help="Number of epochs")
        ...     lr: float = option(0.001, help="Learning rate")
        >>> cfg = Config()
        >>> cfg.epochs
        100
    """
    # Copy instead of mutating: `field()` only wraps the mapping in a read-only
    # proxy, so writing "help" into a caller-owned dict would leak into every
    # other field built from that same dict (and into the caller's dict).
    metadata = dict(kwargs.pop("metadata", None) or {})
    metadata["help"] = help
    return field(default=default, metadata=metadata, **kwargs)

58 

59 

def load_yaml(path: str | Path) -> DictConfig:
    """Read a YAML file from disk and return it as a DictConfig.

    Args:
        path: Location of the YAML file (string or `Path`).

    Returns:
        The parsed YAML content as produced by `OmegaConf.load`.

    Raises:
        ConfigError: If nothing exists at `path`.

    Examples:
        >>> import tempfile
        >>> from pathlib import Path
        >>> with tempfile.NamedTemporaryFile(suffix=".yml", delete=False, mode="w") as f:
        ...     _ = f.write("name: test\\ncount: 42")
        ...     path = f.name
        >>> cfg = load_yaml(path)
        >>> cfg.name
        'test'
        >>> Path(path).unlink()
    """
    yaml_file = Path(path)
    if yaml_file.exists():
        return OmegaConf.load(yaml_file)  # type: ignore[return-value]
    raise ConfigError(f"Config file not found: {yaml_file}")

87 

88 

def parse_overrides(overrides: list[str]) -> DictConfig:
    """Parse CLI overrides into a config tree.

    Supports three types of overrides:
    - `key=value` - Scalar override
    - `key.path=value` - Nested override
    - `key=@file.yml` - Subtree replacement from file

    Whitespace around the key and the value is ignored. Overrides are applied
    in order, so a later override for the same key replaces an earlier one.

    Args:
        overrides: List of override strings.

    Returns:
        DictConfig with parsed overrides.

    Raises:
        ConfigError: If an override doesn't contain '=', if its key has an
            empty segment (e.g. `=1`, `.a=1`, `a..b=1`), or if a referenced
            `@file` does not exist.

    Examples:
        >>> cfg = parse_overrides(["name=test", "count=42"])
        >>> cfg.name
        'test'
        >>> cfg.count
        42
        >>> cfg = parse_overrides(["model.layers=24"])
        >>> cfg.model.layers
        24
    """
    result: dict[str, Any] = {}

    for override in overrides:
        if "=" not in override:
            raise ConfigError(f"Invalid override: '{override}'. Expected 'key=value' format.")

        # Split on the first '=' only so values may themselves contain '='.
        key, value = override.split("=", 1)
        key = key.strip()
        value = value.strip()

        # An empty segment would silently create a "" key in the tree, which
        # can never match a real config field; reject it up front instead.
        key_path = key.split(".")
        if not all(key_path):
            raise ConfigError(
                f"Invalid override: '{override}'. Key must be a non-empty dotted path."
            )

        # Handle @file syntax for subtree replacement
        if value.startswith("@"):
            file_path = value[1:]
            parsed = OmegaConf.to_container(load_yaml(file_path))
        else:
            parsed = _parse_value(value)

        # Build nested dict from dot notation
        _set_nested(result, key_path, parsed)

    return OmegaConf.create(result)

137 

138 

139def _parse_value(value: str) -> Any: 

140 """Parse a string value into Python type. 

141 

142 Args: 

143 value: String to parse. 

144 

145 Returns: 

146 Parsed Python value (bool, None, int, float, list, or str). 

147 

148 Examples: 

149 >>> _parse_value("true") 

150 True 

151 >>> _parse_value("42") 

152 42 

153 >>> _parse_value("3.14") 

154 3.14 

155 >>> _parse_value("[1, 2, 3]") 

156 [1, 2, 3] 

157 """ 

158 # Boolean 

159 if value.lower() == "true": 

160 return True 

161 if value.lower() == "false": 

162 return False 

163 

164 # None 

165 if value.lower() in ("none", "null"): 

166 return None 

167 

168 # List 

169 if value.startswith("[") and value.endswith("]"): 

170 inner = value[1:-1].strip() 

171 if not inner: 

172 return [] 

173 return [_parse_value(v.strip()) for v in inner.split(",")] 

174 

175 # Integer 

176 try: 

177 return int(value) 

178 except ValueError: 

179 pass 

180 

181 # Float 

182 try: 

183 return float(value) 

184 except ValueError: 

185 pass 

186 

187 # Quoted string 

188 if (value.startswith('"') and value.endswith('"')) or ( 188 ↛ 191line 188 didn't jump to line 191 because the condition on line 188 was never true

189 value.startswith("'") and value.endswith("'") 

190 ): 

191 return value[1:-1] 

192 

193 return value 

194 

195 

196def _set_nested(d: dict[str, Any], keys: list[str], value: Any) -> None: 

197 """Set a value in a nested dict using a list of keys. 

198 

199 Args: 

200 d: Dictionary to modify. 

201 keys: List of keys forming the path. 

202 value: Value to set. 

203 

204 Examples: 

205 >>> d = {} 

206 >>> _set_nested(d, ["a", "b", "c"], 42) 

207 >>> d 

208 {'a': {'b': {'c': 42}}} 

209 """ 

210 for key in keys[:-1]: 

211 d = d.setdefault(key, {}) 

212 d[keys[-1]] = value 

213 

214 

def compile_config(
    base: DictConfig | None = None,
    overrides: list[str] | None = None,
    schema: type[T] | None = None,
) -> DictConfig | T:
    """Build the final config by layering schema defaults, base and overrides.

    Later layers win: schema defaults < base < overrides.

    Args:
        base: Base config tree (from YAML).
        overrides: List of override strings (`key=value`, `key=@file`).
        schema: Optional dataclass for type validation.

    Returns:
        An instance produced by `OmegaConf.to_object` when `schema` is given,
        otherwise a DictConfig.

    Raises:
        ConfigError: If the overrides cannot be parsed or cannot be merged
            into the config.

    Examples:
        >>> from dataclasses import dataclass
        >>> @dataclass
        ... class Config:
        ...     name: str = "default"
        ...     count: int = 1
        >>> cfg = compile_config(schema=Config)
        >>> cfg.name
        'default'
        >>> cfg = compile_config(schema=Config, overrides=["name=custom"])
        >>> cfg.name
        'custom'
    """
    # Layers 1 and 2: schema defaults (if any), then the base tree on top.
    if schema is None:
        cfg = OmegaConf.create({}) if base is None else base
    else:
        cfg = OmegaConf.structured(schema)
        if base is not None:
            cfg = OmegaConf.merge(cfg, base)

    # Layer 3: command-line overrides.
    if overrides:
        override_cfg = parse_overrides(overrides)
        try:
            cfg = OmegaConf.merge(cfg, override_cfg)
        except Exception as exc:
            detail = str(exc)
            # Turn OmegaConf's "Key 'x' not in 'Cls'" into a typo hint.
            if "Key" in detail and "not in" in detail:
                import re

                found = re.search(r"Key '(\w+)' not in '(\w+)'", detail)
                if found:
                    key, cls = found.groups()
                    raise ConfigError(
                        f"Invalid config key '{key}' in '{cls}'. Check for typos in your overrides."
                    ) from None
            # Anything else is surfaced with the original message attached.
            raise ConfigError(f"Config error: {detail}") from None

    if schema is None:
        return cfg  # type: ignore[no-any-return]

    # A schema was supplied: hand back a typed object instead of a DictConfig.
    return OmegaConf.to_object(cfg)  # type: ignore[return-value]

280 

281 

def to_yaml(config: Any) -> str:
    """Serialize a config object to a YAML string.

    Args:
        config: Config object (dataclass, dict, or DictConfig).

    Returns:
        The YAML text produced by `OmegaConf.to_yaml`.

    Examples:
        >>> from dataclasses import dataclass
        >>> @dataclass
        ... class Config:
        ...     name: str = "test"
        >>> yaml_str = to_yaml(Config())
        >>> "name: test" in yaml_str
        True
    """
    # Dataclass *instances* only; a dataclass type falls through to create().
    is_dataclass_instance = is_dataclass(config) and not isinstance(config, type)

    if is_dataclass_instance:
        tree = OmegaConf.structured(config)
    elif isinstance(config, DictConfig):
        # Already an OmegaConf tree; use it as-is.
        tree = config
    else:
        tree = OmegaConf.create(config)

    return OmegaConf.to_yaml(tree)

308 

309 

def save_yaml(config: Any, path: str | Path) -> None:
    """Save config to YAML file.

    The file is written as UTF-8, overwriting any existing file at `path`.

    Args:
        config: Config object to save (anything `to_yaml` accepts).
        path: Path to write the YAML file.

    Examples:
        >>> import tempfile
        >>> from dataclasses import dataclass
        >>> from pathlib import Path
        >>> @dataclass
        ... class Config:
        ...     name: str = "test"
        >>> with tempfile.TemporaryDirectory() as tmp:
        ...     target = Path(tmp) / "config.yml"
        ...     save_yaml(Config(), target)
        ...     content = target.read_text(encoding="utf-8")
        >>> "name: test" in content
        True
    """
    # Pin the encoding: the platform default is not UTF-8 everywhere, and
    # non-ASCII values would otherwise be written in a locale-specific
    # encoding that a UTF-8 reader cannot decode.
    Path(path).write_text(to_yaml(config), encoding="utf-8")

330 

331 

332# --- Schema introspection for CLI help --- 

333 

334 

def get_flattened_config_options(
    schema: type,
    prefix: str = "",
) -> list[tuple[str, str, Any, str]]:
    """Recursively flatten nested dataclasses into dotted paths.

    Fields whose type is itself a dataclass are not listed; their own fields
    are listed instead, prefixed with the parent field name and a dot.

    Args:
        schema: Dataclass type to introspect.
        prefix: Current path prefix (for recursion).

    Returns:
        List of tuples: (dotted_name, type_name, default, help_text).
        `default` is the field default, the value produced by the field's
        `default_factory`, or None when the field has neither. An empty list
        is returned when `schema` is not a dataclass.

    Examples:
        >>> from dataclasses import dataclass
        >>> @dataclass
        ... class Config:
        ...     name: str = option("default", help="Name field")
        >>> opts = get_flattened_config_options(Config)
        >>> opts[0][0]
        'name'
    """
    if not is_dataclass(schema):
        return []

    result: list[tuple[str, str, Any, str]] = []

    for f in fields(schema):
        full_name = f"{prefix}.{f.name}" if prefix else f.name
        field_type = f.type

        if is_dataclass(field_type):
            # Recurse into nested dataclass; only its leaf fields are reported.
            result.extend(get_flattened_config_options(field_type, prefix=full_name))  # type: ignore[arg-type]
            continue

        if f.default is not MISSING:
            default = f.default
        elif f.default_factory is not MISSING:
            # Factory defaults (e.g. `list`) would otherwise show up as None.
            default = f.default_factory()
        else:
            default = None

        help_text = f.metadata.get("help", "")
        # Fall back to the string form for types without `__name__`.
        type_name = getattr(field_type, "__name__", str(field_type))
        result.append((full_name, type_name, default, help_text))

    return result

380 

381 

def get_schema_structure(schema: type) -> tuple[dict[str, type], list[tuple[str, str, Any, str]]]:
    """Describe a schema as a (subcommands, config_options) pair.

    Args:
        schema: Dataclass type to introspect.

    Returns:
        Tuple of (subcommands, config_options):
        - subcommands: Always an empty dict (kept for backward compatibility)
        - config_options: List of (dotted_name, type_name, default, help),
          as returned by `get_flattened_config_options`
    """
    subcommands: dict[str, type] = {}
    config_options = get_flattened_config_options(schema)
    return subcommands, config_options