Coverage for src / pipeline / message_stream_processor.py: 87%

132 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2026-01-04 04:43 +0000

1"""MessageStreamProcessor: SDK stream processing component. 

2 

3Extracted from AgentSessionRunner to separate stream iteration logic from 

4session lifecycle management. This module handles: 

5- Wrapping SDK streams with idle timeout detection 

6- Iterating and processing SDK messages (AssistantMessage, ResultMessage) 

7- Tracking tool calls and lint cache updates 

8- Context pressure detection from usage data 

9 

10Design principles: 

11- Protocol-based message/block checks for testability (no SDK imports at runtime) 

12- Explicit state management via MessageIterationState 

13- Callbacks for external operations (text/tool notifications) 

14""" 

15 

16from __future__ import annotations 

17 

18import asyncio 

19import logging 

20import time 

21from collections.abc import Callable 

22from dataclasses import dataclass, field 

23from typing import ( 

24 TYPE_CHECKING, 

25 Any, 

26 Generic, 

27 Protocol, 

28 TypeVar, 

29) 

30 

31if TYPE_CHECKING: 

32 from collections.abc import AsyncIterator 

33 from typing import Self 

34 

35 from src.domain.lifecycle import LifecycleContext 

36 from src.infra.telemetry import TelemetrySpan 

37 

38 

39class LintCacheProtocol(Protocol): 

40 """Protocol for lint cache operations used by stream processor.""" 

41 

42 def detect_lint_command(self, command: str) -> str | None: 

43 """Detect if command is a lint command and return lint type.""" 

44 ... 

45 

46 def mark_success(self, lint_type: str, command: str) -> None: 

47 """Mark a lint command as successful.""" 

48 ... 

49 

50 

51logger = logging.getLogger(__name__) 

52 

53_T = TypeVar("_T") 

54 

55 

56class IdleTimeoutError(Exception): 

57 """Raised when the SDK response stream is idle for too long.""" 

58 

59 

60class ContextPressureError(Exception): 

61 """Raised when context usage exceeds the restart threshold. 

62 

63 This exception signals that the agent session should be checkpointed 

64 and restarted with a fresh context to avoid context exhaustion. 

65 

66 Attributes: 

67 session_id: SDK session ID for checkpoint query. 

68 input_tokens: Current input token count. 

69 output_tokens: Current output token count. 

70 cache_read_tokens: Current cache read token count. 

71 pressure_ratio: Ratio of usage to limit (e.g., 0.92 = 92%). 

72 """ 

73 

74 def __init__( 

75 self, 

76 session_id: str, 

77 input_tokens: int, 

78 output_tokens: int, 

79 cache_read_tokens: int, 

80 pressure_ratio: float, 

81 ) -> None: 

82 self.session_id = session_id 

83 self.input_tokens = input_tokens 

84 self.output_tokens = output_tokens 

85 self.cache_read_tokens = cache_read_tokens 

86 self.pressure_ratio = pressure_ratio 

87 super().__init__( 

88 f"Context pressure {pressure_ratio:.1%} exceeds threshold " 

89 f"(input={input_tokens}, output={output_tokens}, session={session_id})" 

90 ) 

91 

92 

93class IdleTimeoutStream(Generic[_T]): 

94 """Wrap an async iterator with idle timeout detection. 

95 

96 Raises IdleTimeoutError if no message received within timeout, 

97 unless pending_tool_ids is non-empty (tool execution in progress). 

98 """ 

99 

100 def __init__( 

101 self, 

102 stream: AsyncIterator[_T], 

103 timeout_seconds: float | None, 

104 pending_tool_ids: set[str], 

105 ) -> None: 

106 self._stream: AsyncIterator[_T] = stream 

107 self._timeout_seconds = timeout_seconds 

108 self._pending_tool_ids = pending_tool_ids 

109 

110 def __aiter__(self) -> Self: 

111 return self 

112 

113 async def __anext__(self) -> _T: 

114 if self._timeout_seconds is None: 

115 return await self._stream.__anext__() 

116 # Disable timeout if tools are pending (execution in progress) 

117 current_timeout = None if self._pending_tool_ids else self._timeout_seconds 

118 try: 

119 return await asyncio.wait_for( 

120 self._stream.__anext__(), 

121 timeout=current_timeout, 

122 ) 

123 except TimeoutError as exc: 

124 raise IdleTimeoutError( 

125 f"SDK stream idle for {self._timeout_seconds:.0f} seconds" 

126 ) from exc 

127 

128 

129@dataclass 

130class MessageIterationState: 

131 """Mutable state for message iteration within a session. 

132 

133 Used to track state that evolves during SDK message streaming 

134 and idle retry handling. 

135 

136 Attributes: 

137 session_id: SDK session ID (updated when ResultMessage received). 

138 pending_session_id: Session ID to use for resuming after idle timeout. 

139 tool_calls_this_turn: Number of tool calls in the current turn. 

140 idle_retry_count: Number of idle timeout retries attempted. 

141 pending_tool_ids: Set of tool IDs awaiting results. 

142 pending_lint_commands: Map of tool_use_id to (lint_type, command). 

143 first_message_received: Whether any message was received in current turn. 

144 """ 

145 

146 session_id: str | None = None 

147 pending_session_id: str | None = None 

148 tool_calls_this_turn: int = 0 

149 idle_retry_count: int = 0 

150 pending_tool_ids: set[str] = field(default_factory=set) 

151 pending_lint_commands: dict[str, tuple[str, str]] = field(default_factory=dict) 

152 first_message_received: bool = False 

153 

154 

155@dataclass 

156class MessageIterationResult: 

157 """Result from a message iteration. 

158 

159 Attributes: 

160 success: Whether the iteration completed successfully. 

161 session_id: Updated session ID (if received). 

162 pending_query: Next query to send (for retries), or None if complete. 

163 pending_session_id: Session ID to use for next query. 

164 idle_retry_count: Updated idle retry count. 

165 """ 

166 

167 success: bool 

168 session_id: str | None = None 

169 pending_query: str | None = None 

170 pending_session_id: str | None = None 

171 idle_retry_count: int = 0 

172 

173 

174# Callbacks for SDK message streaming events 

175ToolUseCallback = Callable[[str, str, dict[str, Any] | None], None] 

176AgentTextCallback = Callable[[str, str], None] 

177 

178 

179@dataclass 

180class StreamProcessorConfig: 

181 """Configuration for MessageStreamProcessor. 

182 

183 Attributes: 

184 context_limit: Maximum context tokens for pressure detection. 

185 context_restart_threshold: Ratio (0.0-1.0) at which to raise ContextPressureError. 

186 """ 

187 

188 context_limit: int = 200_000 

189 context_restart_threshold: float = 0.90 

190 

191 

192@dataclass 

193class StreamProcessorCallbacks: 

194 """Callbacks for stream processing events. 

195 

196 Attributes: 

197 on_tool_use: Called when ToolUseBlock is encountered. 

198 on_agent_text: Called when TextBlock is encountered. 

199 """ 

200 

201 on_tool_use: ToolUseCallback | None = None 

202 on_agent_text: AgentTextCallback | None = None 

203 

204 

205class MessageStreamProcessor: 

206 """Processes SDK message streams. 

207 

208 Handles iteration over SDK streams, tracking tool calls, updating lint cache, 

209 and detecting context pressure. Uses duck typing for SDK message types to 

210 avoid SDK imports at runtime. 

211 

212 Usage: 

213 processor = MessageStreamProcessor(config, callbacks) 

214 result = await processor.process_stream( 

215 stream, issue_id, state, lifecycle_ctx, lint_cache, query_start, tracer 

216 ) 

217 """ 

218 

219 def __init__( 

220 self, 

221 config: StreamProcessorConfig | None = None, 

222 callbacks: StreamProcessorCallbacks | None = None, 

223 ) -> None: 

224 self.config = config or StreamProcessorConfig() 

225 self.callbacks = callbacks or StreamProcessorCallbacks() 

226 

227 async def process_stream( 

228 self, 

229 stream: AsyncIterator[Any], 

230 issue_id: str, 

231 state: MessageIterationState, 

232 lifecycle_ctx: LifecycleContext, 

233 lint_cache: LintCacheProtocol, 

234 query_start: float, 

235 tracer: TelemetrySpan | None, 

236 ) -> MessageIterationResult: 

237 """Process SDK message stream and update state. 

238 

239 Updates state.session_id, state.tool_calls_this_turn, state.pending_tool_ids, 

240 and lint_cache on successful lint commands. 

241 

242 Args: 

243 stream: The message stream to process. 

244 issue_id: Issue ID for logging. 

245 state: Mutable state for the iteration. 

246 lifecycle_ctx: Lifecycle context for session state. 

247 lint_cache: Cache for lint command results. 

248 query_start: Timestamp when query was sent. 

249 tracer: Optional telemetry span context. 

250 

251 Returns: 

252 MessageIterationResult with success status. 

253 

254 Raises: 

255 ContextPressureError: If context pressure exceeds threshold. 

256 """ 

257 # Use duck typing to avoid SDK imports - check type name instead of isinstance 

258 async for message in stream: 

259 if not state.first_message_received: 

260 state.first_message_received = True 

261 latency = time.time() - query_start 

262 logger.debug( 

263 "Session %s: first message after %.1fs", 

264 issue_id, 

265 latency, 

266 ) 

267 if tracer is not None: 

268 tracer.log_message(message) 

269 

270 msg_type = type(message).__name__ 

271 if msg_type == "AssistantMessage": 

272 self._process_assistant_message(message, issue_id, state, lint_cache) 

273 

274 elif msg_type == "ResultMessage": 

275 self._process_result_message(message, issue_id, state, lifecycle_ctx) 

276 

277 # Success 

278 stream_duration = time.time() - query_start 

279 logger.debug( 

280 "Session %s: stream complete after %.1fs, %d tool calls", 

281 issue_id, 

282 stream_duration, 

283 state.tool_calls_this_turn, 

284 ) 

285 return MessageIterationResult( 

286 success=True, 

287 session_id=state.session_id, 

288 idle_retry_count=0, 

289 ) 

290 

291 def _process_assistant_message( 

292 self, 

293 message: object, 

294 issue_id: str, 

295 state: MessageIterationState, 

296 lint_cache: LintCacheProtocol, 

297 ) -> None: 

298 """Process an AssistantMessage, handling text/tool blocks.""" 

299 content = getattr(message, "content", []) 

300 for block in content: 

301 block_type = type(block).__name__ 

302 if block_type == "TextBlock": 

303 text = getattr(block, "text", "") 

304 if self.callbacks.on_agent_text is not None: 

305 self.callbacks.on_agent_text(issue_id, text) 

306 elif block_type == "ToolUseBlock": 

307 state.tool_calls_this_turn += 1 

308 block_id = getattr(block, "id", "") 

309 state.pending_tool_ids.add(block_id) 

310 name = getattr(block, "name", "") 

311 block_input = getattr(block, "input", {}) 

312 if self.callbacks.on_tool_use is not None: 

313 self.callbacks.on_tool_use(issue_id, name, block_input) 

314 if name.lower() == "bash": 

315 cmd = block_input.get("command", "") 

316 lint_type = lint_cache.detect_lint_command(cmd) 

317 if lint_type: 

318 state.pending_lint_commands[block_id] = ( 

319 lint_type, 

320 cmd, 

321 ) 

322 elif block_type == "ToolResultBlock": 

323 tool_use_id = getattr(block, "tool_use_id", None) 

324 if tool_use_id: 

325 state.pending_tool_ids.discard(tool_use_id) 

326 if tool_use_id in state.pending_lint_commands: 

327 lint_type, cmd = state.pending_lint_commands.pop(tool_use_id) 

328 if not getattr(block, "is_error", False): 

329 lint_cache.mark_success(lint_type, cmd) 

330 

331 def _process_result_message( 

332 self, 

333 message: object, 

334 issue_id: str, 

335 state: MessageIterationState, 

336 lifecycle_ctx: LifecycleContext, 

337 ) -> None: 

338 """Process a ResultMessage, extracting session ID and usage. 

339 

340 Raises: 

341 ContextPressureError: If context pressure exceeds threshold. 

342 """ 

343 state.session_id = getattr(message, "session_id", None) 

344 lifecycle_ctx.session_id = state.session_id 

345 lifecycle_ctx.final_result = getattr(message, "result", "") or "" 

346 

347 # Extract token usage from SDK for context pressure detection 

348 usage = getattr(message, "usage", None) 

349 if usage is not None: 

350 # Handle both dict and object forms of usage 

351 if isinstance(usage, dict): 

352 input_tokens = usage.get("input_tokens", 0) or 0 

353 output_tokens = usage.get("output_tokens", 0) or 0 

354 cache_read = usage.get("cache_read_input_tokens", 0) or 0 

355 else: 

356 input_tokens = getattr(usage, "input_tokens", 0) or 0 

357 output_tokens = getattr(usage, "output_tokens", 0) or 0 

358 cache_read = getattr(usage, "cache_read_input_tokens", 0) or 0 

359 

360 # Update lifecycle context with cumulative usage 

361 lifecycle_ctx.context_usage.input_tokens = input_tokens 

362 lifecycle_ctx.context_usage.output_tokens = output_tokens 

363 lifecycle_ctx.context_usage.cache_read_tokens = cache_read 

364 

365 # Check context pressure threshold 

366 pressure = lifecycle_ctx.context_usage.pressure_ratio( 

367 self.config.context_limit 

368 ) 

369 if pressure >= self.config.context_restart_threshold: 

370 # session_id was already extracted above 

371 raise ContextPressureError( 

372 session_id=state.session_id or "", 

373 input_tokens=input_tokens, 

374 output_tokens=output_tokens, 

375 cache_read_tokens=cache_read, 

376 pressure_ratio=pressure, 

377 ) 

378 else: 

379 logger.warning( 

380 "Session %s: ResultMessage missing usage field, " 

381 "context pressure tracking disabled", 

382 issue_id, 

383 ) 

384 lifecycle_ctx.context_usage.disable_tracking()