Coverage for src / pipeline / context_pressure_handler.py: 100%

64 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2026-01-04 04:43 +0000

1"""ContextPressureHandler: Context pressure checkpoint and restart handling. 

2 

3Extracted from AgentSessionRunner to separate checkpoint/restart logic from 

4session lifecycle management. This module handles: 

5- Fetching checkpoints from agents via SDK client 

6- Building continuation prompts with checkpoint context 

7- Managing restart loop state (continuation count, prompts) 

8 

9Design principles: 

10- Protocol-based SDK client for testability 

11- Explicit error type for pressure threshold detection 

12- Timeout handling for checkpoint fetch operations 

13""" 

14 

15from __future__ import annotations 

16 

17import asyncio 

18import logging 

19from dataclasses import dataclass 

20from typing import ( 

21 TYPE_CHECKING, 

22 Any, 

23 cast, 

24) 

25 

26from src.domain.prompts import ( 

27 build_continuation_prompt, 

28 extract_checkpoint, 

29) 

30from src.pipeline.message_stream_processor import ContextPressureError 

31 

32if TYPE_CHECKING: 

33 from src.core.protocols import ( 

34 SDKClientFactoryProtocol, 

35 ) 

36 

37 

logger = logging.getLogger(__name__)

# Time budget, in seconds, for one checkpoint fetch when the configuration
# does not supply its own value.
DEFAULT_CHECKPOINT_TIMEOUT_SECONDS = 30


@dataclass
class ContextPressureConfig:
    """Settings that drive checkpoint/restart handling under context pressure.

    Attributes:
        checkpoint_request_prompt: Prompt sent to the running session to ask
            for a checkpoint. When empty, the handler skips the fetch and
            continues with an empty checkpoint.
        continuation_template: Template used to build the continuation prompt
            around the checkpoint. When empty, the handler falls back to a
            built-in prompt.
        checkpoint_timeout_seconds: Maximum time, in seconds, to wait for the
            checkpoint response.
    """

    # Both prompts are required; only the timeout has a default.
    checkpoint_request_prompt: str
    continuation_template: str
    checkpoint_timeout_seconds: float = DEFAULT_CHECKPOINT_TIMEOUT_SECONDS

57 

58 

@dataclass
class CheckpointResult:
    """Outcome of asking an agent session for a checkpoint.

    Attributes:
        checkpoint: The checkpoint text that was extracted; an empty string
            when nothing could be obtained.
        timed_out: True when the fetch was abandoned because the time budget
            ran out.
    """

    # `checkpoint` is required; a result is "not timed out" unless stated.
    checkpoint: str
    timed_out: bool = False

70 

71 

class ContextPressureHandler:
    """Fetch checkpoints and build continuation prompts on context pressure.

    The handler:
    - asks the pressured session for a checkpoint before it is restarted,
    - turns that checkpoint into the prompt for the continuation session,
    - advances the restart counter supplied by the caller.

    It keeps no restart state of its own: the continuation count is passed in
    and the incremented value is returned, so the caller owns that state.
    """

    def __init__(
        self,
        config: ContextPressureConfig,
        sdk_client_factory: SDKClientFactoryProtocol,
    ) -> None:
        """Initialize the handler.

        Args:
            config: Context pressure configuration.
            sdk_client_factory: Factory for creating SDK clients.
        """
        self.config = config
        self.sdk_client_factory = sdk_client_factory

    @staticmethod
    def _text_fragments(message: object) -> list[str]:
        """Return the text of every content block of *message* that has one.

        Messages are duck-typed: a message without a ``content`` attribute (or
        with ``content`` set to None) yields no fragments, and blocks without a
        ``text`` attribute (or with ``text`` set to None) are skipped.
        """
        content = getattr(message, "content", None)
        if content is None:
            return []
        return [
            text
            for block in cast("list[Any]", content)
            if (text := getattr(block, "text", None)) is not None
        ]

    async def fetch_checkpoint(
        self,
        session_id: str,
        issue_id: str,
        options: object,
        timeout_seconds: float | None = None,
    ) -> CheckpointResult:
        """Fetch checkpoint from agent before context restart.

        Sends checkpoint_request_prompt to the current session and extracts
        the checkpoint block from the response. Failures inside the query are
        best-effort: a timeout or any other exception is logged and turned
        into an empty checkpoint instead of being raised.

        Args:
            session_id: SDK session ID from ContextPressureError.
            issue_id: Issue ID for logging.
            options: SDK client options.
            timeout_seconds: Optional override for checkpoint timeout; when
                None the configured checkpoint_timeout_seconds is used.

        Returns:
            CheckpointResult with extracted checkpoint text. The checkpoint is
            empty when no request prompt is configured, when the fetch times
            out (timed_out=True) or when the query fails.

        Raises:
            Any exception raised by sdk_client_factory.create(); the client is
            created outside the guarded section.
        """
        effective_timeout = (
            timeout_seconds
            if timeout_seconds is not None
            else self.config.checkpoint_timeout_seconds
        )

        logger.info(
            "Session %s: requesting checkpoint from session %s...",
            issue_id,
            session_id[:8] if session_id else "unknown",
        )

        checkpoint_prompt = self.config.checkpoint_request_prompt
        if not checkpoint_prompt:
            logger.warning(
                "Session %s: no checkpoint_request prompt configured, using empty checkpoint",
                issue_id,
            )
            return CheckpointResult(checkpoint="", timed_out=False)

        # Create client to query for checkpoint
        client = self.sdk_client_factory.create(options)

        # Collect fragments and join once instead of growing a string with
        # `+=` for every block of every streamed message.
        fragments: list[str] = []
        try:
            async with asyncio.timeout(effective_timeout):
                async with client:
                    await client.query(checkpoint_prompt, session_id=session_id)
                    async for message in client.receive_response():
                        fragments.extend(self._text_fragments(message))
            # Joined inside the guarded section on purpose: a non-string
            # `text` value raises TypeError here and takes the same
            # "query failed" fallback below rather than escaping the method.
            response_text = "".join(fragments)
        except TimeoutError:
            logger.warning(
                "Session %s: checkpoint fetch timed out after %.1fs, using empty checkpoint",
                issue_id,
                effective_timeout,
            )
            return CheckpointResult(checkpoint="", timed_out=True)
        except Exception as e:
            # Deliberate best-effort boundary: a failed checkpoint query must
            # not prevent the restart, so log and continue without checkpoint.
            logger.warning(
                "Session %s: checkpoint query failed: %s, using empty checkpoint",
                issue_id,
                e,
            )
            return CheckpointResult(checkpoint="", timed_out=False)

        # Extract checkpoint from response
        checkpoint = extract_checkpoint(response_text)
        logger.debug(
            "Session %s: extracted checkpoint (%d chars)",
            issue_id,
            len(checkpoint),
        )
        return CheckpointResult(checkpoint=checkpoint, timed_out=False)

    def build_continuation_prompt(self, checkpoint: str) -> str:
        """Build continuation prompt with checkpoint context.

        Args:
            checkpoint: Extracted checkpoint text from previous session.

        Returns:
            The configured continuation template combined with the checkpoint,
            or a generic "Continue from checkpoint" prompt when no template is
            configured.
        """
        continuation_template = self.config.continuation_template
        if continuation_template:
            # Resolves to the module-level helper of the same name, not to
            # this method.
            return build_continuation_prompt(continuation_template, checkpoint)
        # Fallback: just use checkpoint as prompt
        return f"Continue from checkpoint:\n\n{checkpoint}"

    async def handle_pressure_error(
        self,
        error: ContextPressureError,
        issue_id: str,
        options: object,
        continuation_count: int,
        remaining_time: float,
    ) -> tuple[str, int]:
        """Handle a ContextPressureError by fetching checkpoint and building continuation.

        This is the main entry point for handling context pressure. It:
        1. Calculates effective timeout based on remaining session time
        2. Fetches checkpoint from the current session
        3. Logs restart information
        4. Builds the continuation prompt

        Args:
            error: The ContextPressureError that triggered the restart; its
                session_id and pressure_ratio are read.
            issue_id: Issue ID for logging.
            options: SDK client options.
            continuation_count: Current restart count; the returned count is
                this value plus one.
            remaining_time: Remaining session time in seconds.

        Returns:
            Tuple of (continuation_prompt, new_continuation_count).
        """
        # Never wait longer than the session has left, and never pass a
        # negative timeout on to the fetch.
        effective_timeout = max(
            0, min(remaining_time, self.config.checkpoint_timeout_seconds)
        )

        # Fetch checkpoint from current session
        result = await self.fetch_checkpoint(
            session_id=error.session_id,
            issue_id=issue_id,
            options=options,
            timeout_seconds=effective_timeout,
        )

        new_count = continuation_count + 1
        logger.info(
            "Session %s: context restart #%d at %.1f%%",
            issue_id,
            new_count,
            error.pressure_ratio * 100,
        )

        # Build continuation prompt
        continuation_prompt = self.build_continuation_prompt(result.checkpoint)

        return continuation_prompt, new_count

244 

245 

# Public API of this module. ContextPressureError is defined in
# message_stream_processor and listed here so callers can import it from
# this module as well.
__all__ = [
    "DEFAULT_CHECKPOINT_TIMEOUT_SECONDS",
    "CheckpointResult",
    "ContextPressureConfig",
    "ContextPressureError",
    "ContextPressureHandler",
]

253]