Coverage for src / domain / prompts.py: 64%
66 statements
« prev ^ index » next coverage.py v7.13.0, created at 2026-01-04 04:43 +0000
« prev ^ index » next coverage.py v7.13.0, created at 2026-01-04 04:43 +0000
"""Shared prompt loading utilities.

This module centralizes prompt file loading to avoid duplication across modules.
"""
6from __future__ import annotations
8import re
9from dataclasses import dataclass
10from pathlib import Path
11from typing import TYPE_CHECKING
13if TYPE_CHECKING:
14 from src.domain.validation.config import PromptValidationCommands
@dataclass(frozen=True)
class PromptProvider:
    """Immutable container for all loaded prompt templates.

    This is a pure data object constructed at the startup boundary.
    Every field holds the string content of one prompt file, and because
    the dataclass is frozen, fields cannot be reassigned after construction.
    """

    # Template filled in by format_implementer_prompt().
    implementer_prompt: str
    # Loaded from review_followup.md by load_prompts().
    review_followup_prompt: str
    # Loaded from gate_followup.md by load_prompts().
    gate_followup_prompt: str
    # Loaded from fixer.md by load_prompts().
    fixer_prompt: str
    # Loaded from idle_resume.md by load_prompts().
    idle_resume_prompt: str
    # Loaded from checkpoint_request.md by load_prompts().
    checkpoint_request_prompt: str
    # Template consumed by build_continuation_prompt().
    continuation_prompt: str
def load_prompts(prompt_dir: Path) -> PromptProvider:
    """Load all prompt templates from disk.

    Files are decoded as UTF-8 explicitly so the loaded text does not
    depend on the locale encoding of the machine running the process.

    Args:
        prompt_dir: Directory containing prompt template files.

    Returns:
        PromptProvider with all loaded prompt templates.

    Raises:
        FileNotFoundError: If any required prompt file is missing.
        UnicodeDecodeError: If a prompt file is not valid UTF-8.
    """

    def read(filename: str) -> str:
        # Single place that fixes the encoding for every template file.
        return (prompt_dir / filename).read_text(encoding="utf-8")

    return PromptProvider(
        implementer_prompt=read("implementer_prompt.md"),
        review_followup_prompt=read("review_followup.md"),
        gate_followup_prompt=read("gate_followup.md"),
        fixer_prompt=read("fixer.md"),
        idle_resume_prompt=read("idle_resume.md"),
        checkpoint_request_prompt=read("checkpoint_request.md"),
        continuation_prompt=read("continuation.md"),
    )
def format_implementer_prompt(
    implementer_prompt: str,
    issue_id: str,
    repo_path: Path,
    agent_id: str,
    validation_commands: PromptValidationCommands,
    lock_dir: Path,
    scripts_dir: Path,
) -> str:
    """Format the implementer prompt with runtime values.

    The template is filled with ``str.format``, so it may reference the
    placeholders ``{issue_id}``, ``{repo_path}``, ``{lock_dir}``,
    ``{scripts_dir}``, ``{agent_id}``, ``{lint_command}``,
    ``{format_command}``, ``{typecheck_command}`` and ``{test_command}``.
    Any literal brace in the template must be doubled (``{{`` / ``}}``).

    Args:
        implementer_prompt: The raw implementer prompt template.
        issue_id: The issue ID being implemented.
        repo_path: Path to the repository.
        agent_id: The agent ID for this session.
        validation_commands: Validation commands for the prompt; its
            ``lint``, ``format``, ``typecheck`` and ``test`` attributes
            are substituted into the template.
        lock_dir: Directory for lock files (from infra layer).
        scripts_dir: Directory containing helper scripts (from infra layer).

    Returns:
        Formatted prompt string.

    Raises:
        KeyError: If the template references a placeholder that is not
            supplied here.
    """
    substitutions = {
        "issue_id": issue_id,
        "repo_path": repo_path,
        "lock_dir": lock_dir,
        "scripts_dir": scripts_dir,
        "agent_id": agent_id,
        "lint_command": validation_commands.lint,
        "format_command": validation_commands.format,
        "typecheck_command": validation_commands.typecheck,
        "test_command": validation_commands.test,
    }
    return implementer_prompt.format(**substitutions)
def get_default_validation_commands() -> PromptValidationCommands:
    """Return default Python/uv validation commands with cache isolation.

    These defaults are used when no mala.yaml configuration is found.
    The lint, format and test commands point their caches at a
    per-agent directory under /tmp, using the $AGENT_ID environment
    variable (set in the agent environment) and falling back to
    "default" when it is unset, so parallel agent runs do not share a
    cache.

    Returns:
        PromptValidationCommands with default Python/uv toolchain commands.
    """
    # Imported at call time; the module-level import of this name is
    # guarded by TYPE_CHECKING and is therefore not available at runtime.
    from src.domain.validation.config import PromptValidationCommands

    return PromptValidationCommands(
        lint="RUFF_CACHE_DIR=/tmp/ruff-${AGENT_ID:-default} uvx ruff check .",
        format="RUFF_CACHE_DIR=/tmp/ruff-${AGENT_ID:-default} uvx ruff format .",
        typecheck="uvx ty check",
        test="uv run pytest -o cache_dir=/tmp/pytest-${AGENT_ID:-default}",
    )
def _default_prompt_dir() -> Path:
    """Return the default prompts directory.

    The result is the ``prompts`` directory that is a sibling of the
    directory containing this module (two ``.parent`` steps up from this
    file, then ``prompts``). The path is not checked for existence.
    """
    module_dir = Path(__file__).parent
    return module_dir.parent / "prompts"
def load_prompt(name: str, prompt_dir: Path | None = None) -> str:
    """Load a single prompt template by name.

    The file is decoded as UTF-8 explicitly so the result does not depend
    on the locale encoding of the machine running the process.

    Args:
        name: Name of the prompt (without .md extension).
        prompt_dir: Directory containing prompt files. Defaults to src/prompts.

    Returns:
        The prompt template content.

    Raises:
        FileNotFoundError: If the prompt file doesn't exist.
        UnicodeDecodeError: If the prompt file is not valid UTF-8.
    """
    base_dir = _default_prompt_dir() if prompt_dir is None else prompt_dir
    return (base_dir / f"{name}.md").read_text(encoding="utf-8")
_CHECKPOINT_OPEN_TAG = "<checkpoint>"
# Matches either an opening or a closing checkpoint tag.
_CHECKPOINT_TAG_RE = re.compile(r"</?checkpoint>")
# Opening code fence (optionally with a language tag) at the very start.
_LEADING_FENCE_RE = re.compile(r"\A\s*```\w*[ \t]*\n?")
# Closing code fence at the very end; trailing whitespace, including a
# final newline, is tolerated so responses ending in "```\n" are unwrapped.
_TRAILING_FENCE_RE = re.compile(r"\n?[ \t]*```\s*\Z")


def extract_checkpoint(text: str) -> str:
    """Extract checkpoint block from agent response text.

    Looks for content between <checkpoint> and </checkpoint> tags.
    For nested tags, returns the outermost checkpoint content (inner tags
    are kept verbatim in the result). If the first opening tag is never
    properly closed, everything after it is returned.
    Returns full text as fallback if no tags found, stripping a leading
    and/or trailing Markdown code fence.

    Args:
        text: Raw agent response text.

    Returns:
        Extracted checkpoint content, or full text if no tags found.
    """
    open_at = text.find(_CHECKPOINT_OPEN_TAG)
    if open_at == -1:
        # Fallback: strip code block wrappers and return.
        without_leading = _LEADING_FENCE_RE.sub("", text)
        return _TRAILING_FENCE_RE.sub("", without_leading)

    content_start = open_at + len(_CHECKPOINT_OPEN_TAG)

    # Walk the remaining tags in order, tracking nesting depth, until the
    # closing tag that balances the first opening tag is found.
    depth = 1
    for tag in _CHECKPOINT_TAG_RE.finditer(text, content_start):
        if tag.group() == _CHECKPOINT_OPEN_TAG:
            depth += 1
            continue
        depth -= 1
        if depth == 0:
            return text[content_start : tag.start()]

    # No balancing closing tag: return from the first opening tag to the end.
    return text[content_start:]
def build_continuation_prompt(continuation_template: str, checkpoint_text: str) -> str:
    """Build a continuation prompt with checkpoint context.

    Every literal occurrence of ``{checkpoint}`` in the template is replaced
    with the checkpoint text; all other braces are left untouched.

    Args:
        continuation_template: The continuation prompt template from PromptProvider.
        checkpoint_text: The checkpoint block from the previous session.

    Returns:
        Formatted continuation prompt with checkpoint embedded.
    """
    # Use str.replace instead of str.format to avoid KeyError if checkpoint
    # contains curly braces (e.g., JSON or code snippets).
    placeholder = "{checkpoint}"
    return continuation_template.replace(placeholder, checkpoint_text)
def build_prompt_validation_commands(repo_path: Path) -> PromptValidationCommands:
    """Build PromptValidationCommands for a repository.

    Loads the mala.yaml configuration, merges with preset if specified,
    and returns the validation commands formatted for prompt templates.

    Args:
        repo_path: Path to the repository root directory.

    Returns:
        PromptValidationCommands with command strings for prompt templates.
        Returns default Python/uv commands if no config is found.
    """
    # Imported at call time rather than at module level.
    # NOTE(review): presumably this avoids an import cycle with the
    # validation package — confirm.
    from src.domain.validation.config import PromptValidationCommands
    from src.domain.validation.config_loader import ConfigMissingError, load_config
    from src.domain.validation.config_merger import merge_configs
    from src.domain.validation.preset_registry import PresetRegistry

    try:
        user_config = load_config(repo_path)
    except ConfigMissingError:
        # No config file - return defaults.
        return get_default_validation_commands()

    if user_config.preset is None:
        # Nothing to merge: the user's config is used as-is.
        effective_config = user_config
    else:
        # Look up the named preset and merge it with the user's config.
        preset_config = PresetRegistry().get(user_config.preset)
        effective_config = merge_configs(preset_config, user_config)

    return PromptValidationCommands.from_validation_config(effective_config)