Coverage for src / domain / prompts.py: 64%

66 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2026-01-04 04:43 +0000

1"""Shared prompt loading utilities. 

2 

3This module centralizes prompt file loading to avoid duplication across modules. 

4""" 

5 

6from __future__ import annotations 

7 

8import re 

9from dataclasses import dataclass 

10from pathlib import Path 

11from typing import TYPE_CHECKING 

12 

13if TYPE_CHECKING: 

14 from src.domain.validation.config import PromptValidationCommands 

15 

16 

@dataclass(frozen=True)
class PromptProvider:
    """Immutable bundle of the raw prompt templates used by the application.

    Built once at the startup boundary (``load_prompts`` constructs it) and
    passed around read-only afterwards. Every field is the unformatted text
    content of one prompt file; placeholder substitution happens elsewhere.
    """

    implementer_prompt: str  # contents of implementer_prompt.md
    review_followup_prompt: str  # contents of review_followup.md
    gate_followup_prompt: str  # contents of gate_followup.md
    fixer_prompt: str  # contents of fixer.md
    idle_resume_prompt: str  # contents of idle_resume.md
    checkpoint_request_prompt: str  # contents of checkpoint_request.md
    continuation_prompt: str  # contents of continuation.md

32 

33 

def load_prompts(prompt_dir: Path) -> PromptProvider:
    """Load all prompt templates from disk.

    Each template file is decoded as UTF-8 explicitly, so loading does not
    depend on the platform's locale encoding (prompt Markdown may contain
    non-ASCII characters such as em dashes or quotes).

    Args:
        prompt_dir: Directory containing prompt template files.

    Returns:
        PromptProvider with all loaded prompt templates.

    Raises:
        FileNotFoundError: If any required prompt file is missing.
        UnicodeDecodeError: If a prompt file is not valid UTF-8.
    """

    def read(filename: str) -> str:
        # Explicit encoding: Path.read_text() otherwise uses the locale encoding.
        return (prompt_dir / filename).read_text(encoding="utf-8")

    # Keyword arguments are evaluated left to right, so files are read in the
    # same order as before and the first missing file is the one reported.
    return PromptProvider(
        implementer_prompt=read("implementer_prompt.md"),
        review_followup_prompt=read("review_followup.md"),
        gate_followup_prompt=read("gate_followup.md"),
        fixer_prompt=read("fixer.md"),
        idle_resume_prompt=read("idle_resume.md"),
        checkpoint_request_prompt=read("checkpoint_request.md"),
        continuation_prompt=read("continuation.md"),
    )

55 

56 

def format_implementer_prompt(
    implementer_prompt: str,
    issue_id: str,
    repo_path: Path,
    agent_id: str,
    validation_commands: PromptValidationCommands,
    lock_dir: Path,
    scripts_dir: Path,
) -> str:
    """Render the implementer prompt template for one agent session.

    Fills the template via ``str.format`` with these placeholders:
    ``{issue_id}``, ``{repo_path}``, ``{lock_dir}``, ``{scripts_dir}``,
    ``{agent_id}``, ``{lint_command}``, ``{format_command}``,
    ``{typecheck_command}`` and ``{test_command}``. Because ``str.format`` is
    used, literal braces in the template must be written as ``{{``/``}}``.

    Args:
        implementer_prompt: Raw implementer prompt template text.
        issue_id: Identifier of the issue being implemented.
        repo_path: Repository location.
        agent_id: Identifier of this agent session.
        validation_commands: Provides the ``lint``, ``format``, ``typecheck``
            and ``test`` command strings.
        lock_dir: Directory for lock files (supplied by the infra layer).
        scripts_dir: Directory of helper scripts (supplied by the infra layer).

    Returns:
        The template with all placeholders substituted.

    Raises:
        KeyError: If the template references a placeholder not listed above.
    """
    cmds = validation_commands
    substitutions = {
        "issue_id": issue_id,
        "repo_path": repo_path,
        "lock_dir": lock_dir,
        "scripts_dir": scripts_dir,
        "agent_id": agent_id,
        "lint_command": cmds.lint,
        "format_command": cmds.format,
        "typecheck_command": cmds.typecheck,
        "test_command": cmds.test,
    }
    return implementer_prompt.format(**substitutions)

91 

92 

def get_default_validation_commands() -> PromptValidationCommands:
    """Build the fallback validation commands for a Python project using uv.

    These are used when no mala.yaml configuration is found. The ruff and
    pytest commands point their caches at per-agent directories under /tmp,
    keyed by the shell-style ``${AGENT_ID:-default}`` expansion, so agents
    running in parallel do not share cache state. ``AGENT_ID`` is expected to
    be set in the agent's environment.

    Returns:
        PromptValidationCommands populated with the default uv/Python
        toolchain commands.
    """
    # The module-level import of this name is TYPE_CHECKING-only, so import
    # it here at call time (presumably to avoid an import cycle).
    from src.domain.validation.config import PromptValidationCommands

    lint_cmd = "RUFF_CACHE_DIR=/tmp/ruff-${AGENT_ID:-default} uvx ruff check ."
    format_cmd = "RUFF_CACHE_DIR=/tmp/ruff-${AGENT_ID:-default} uvx ruff format ."
    typecheck_cmd = "uvx ty check"
    test_cmd = "uv run pytest -o cache_dir=/tmp/pytest-${AGENT_ID:-default}"

    return PromptValidationCommands(
        lint=lint_cmd,
        format=format_cmd,
        typecheck=typecheck_cmd,
        test=test_cmd,
    )

111 

112 

def _default_prompt_dir() -> Path:
    """Locate the bundled ``prompts`` directory.

    The directory is resolved relative to this module file: two directory
    levels up from it, then into ``prompts``.
    """
    two_levels_up = Path(__file__).parent.parent
    return two_levels_up / "prompts"

116 

117 

def load_prompt(name: str, prompt_dir: Path | None = None) -> str:
    """Load a single prompt template by name.

    Reads ``<prompt_dir>/<name>.md`` and decodes it as UTF-8 explicitly, so
    the result does not depend on the platform's locale encoding.

    Args:
        name: Name of the prompt (without .md extension).
        prompt_dir: Directory containing prompt files. Defaults to the
            bundled prompts directory (src/prompts).

    Returns:
        The prompt template content.

    Raises:
        FileNotFoundError: If the prompt file doesn't exist.
        UnicodeDecodeError: If the prompt file is not valid UTF-8.
    """
    directory = _default_prompt_dir() if prompt_dir is None else prompt_dir
    # Explicit encoding: Path.read_text() otherwise uses the locale encoding.
    return (directory / f"{name}.md").read_text(encoding="utf-8")

134 

135 

def extract_checkpoint(text: str) -> str:
    """Extract checkpoint block from agent response text.

    Looks for content between <checkpoint> and </checkpoint> tags, starting
    at the first opening tag. Nested tags are depth-tracked, so the
    outermost checkpoint's content is returned, including any inner tags.
    If the first opening tag is never closed, everything after it is
    returned.

    If no opening tag exists, the full text is returned as a fallback, with a
    surrounding Markdown code fence stripped:
    - a leading ```` ```lang ```` line is removed;
    - a trailing ```` ``` ```` is removed, even when it is followed by
      trailing whitespace or newlines.

    Args:
        text: Raw agent response text.

    Returns:
        Extracted checkpoint content, or the (fence-stripped) full text if no
        tags were found.
    """
    open_tag = "<checkpoint>"
    close_tag = "</checkpoint>"

    first_open = text.find(open_tag)
    if first_open == -1:
        # Fallback: strip code block wrappers and return.
        stripped = re.sub(r"\A\s*```\w*[ \t]*\n?", "", text)
        # \s* (not just [ \t]*) before \Z so a closing fence followed by a
        # trailing newline is still recognized and removed.
        return re.sub(r"\n?[ \t]*```\s*\Z", "", stripped)

    # Track nesting depth to find the matching closing tag.
    start_pos = first_open + len(open_tag)
    depth = 1
    pos = start_pos

    while depth > 0 and pos < len(text):
        next_open = text.find(open_tag, pos)
        next_close = text.find(close_tag, pos)

        if next_close == -1:
            # No more closing tags; fall through and return to end of text.
            break

        if next_open != -1 and next_open < next_close:
            # A nested opening tag comes before the next close.
            depth += 1
            pos = next_open + len(open_tag)
        else:
            depth -= 1
            if depth == 0:
                return text[start_pos:next_close]
            pos = next_close + len(close_tag)

    # No matching closing tag found: return from the first opening tag to end.
    return text[start_pos:]

183 

184 

def build_continuation_prompt(continuation_template: str, checkpoint_text: str) -> str:
    """Embed a previous session's checkpoint into the continuation template.

    Every literal ``{checkpoint}`` marker in the template is replaced by
    ``checkpoint_text``. Nothing else in the template is interpreted, so
    other braces (including ``{{``/``}}``) are left exactly as written.

    Args:
        continuation_template: The continuation prompt template from PromptProvider.
        checkpoint_text: The checkpoint block from the previous session.

    Returns:
        The continuation prompt with the checkpoint text substituted in.
    """
    # Deliberately avoid str.format: the checkpoint may contain curly braces
    # (e.g. JSON or code snippets), which format() would treat as fields.
    marker = "{checkpoint}"
    segments = continuation_template.split(marker)
    return checkpoint_text.join(segments)

198 

199 

def build_prompt_validation_commands(repo_path: Path) -> PromptValidationCommands:
    """Resolve the validation commands to show in prompts for a repository.

    Resolution order:
    1. The repository's mala.yaml is read via ``load_config``.
    2. If that raises ``ConfigMissingError``, the built-in uv/Python defaults
       from ``get_default_validation_commands`` are returned.
    3. If the loaded config names a preset, the preset (looked up in
       ``PresetRegistry``) is combined with the user config via
       ``merge_configs(preset_config, user_config)``.
    4. The resulting config is converted with
       ``PromptValidationCommands.from_validation_config``.

    Args:
        repo_path: Path to the repository root directory.

    Returns:
        PromptValidationCommands with command strings for prompt templates.
    """
    # Deferred imports from the validation layer (presumably to avoid import
    # cycles; the module-level PromptValidationCommands import is
    # TYPE_CHECKING-only).
    from src.domain.validation.config import PromptValidationCommands
    from src.domain.validation.config_loader import ConfigMissingError, load_config
    from src.domain.validation.config_merger import merge_configs
    from src.domain.validation.preset_registry import PresetRegistry

    try:
        user_config = load_config(repo_path)
    except ConfigMissingError:
        return get_default_validation_commands()

    preset_name = user_config.preset
    if preset_name is None:
        effective_config = user_config
    else:
        preset_config = PresetRegistry().get(preset_name)
        # Preset first, user config second; presumably user settings win in
        # merge_configs - confirm against its implementation.
        effective_config = merge_configs(preset_config, user_config)

    return PromptValidationCommands.from_validation_config(effective_config)