Coverage for src / pipeline / message_stream_processor.py: 87%
132 statements
« prev ^ index » next coverage.py v7.13.0, created at 2026-01-04 04:43 +0000
« prev ^ index » next coverage.py v7.13.0, created at 2026-01-04 04:43 +0000
1"""MessageStreamProcessor: SDK stream processing component.
3Extracted from AgentSessionRunner to separate stream iteration logic from
4session lifecycle management. This module handles:
5- Wrapping SDK streams with idle timeout detection
6- Iterating and processing SDK messages (AssistantMessage, ResultMessage)
7- Tracking tool calls and lint cache updates
8- Context pressure detection from usage data
10Design principles:
11- Protocol-based message/block checks for testability (no SDK imports at runtime)
12- Explicit state management via MessageIterationState
13- Callbacks for external operations (text/tool notifications)
14"""
16from __future__ import annotations
18import asyncio
19import logging
20import time
21from collections.abc import Callable
22from dataclasses import dataclass, field
23from typing import (
24 TYPE_CHECKING,
25 Any,
26 Generic,
27 Protocol,
28 TypeVar,
29)
31if TYPE_CHECKING:
32 from collections.abc import AsyncIterator
33 from typing import Self
35 from src.domain.lifecycle import LifecycleContext
36 from src.infra.telemetry import TelemetrySpan
class LintCacheProtocol(Protocol):
    """Structural interface for the lint cache the stream processor talks to.

    Any object providing these two methods satisfies the protocol; no
    inheritance from this class is required.
    """

    def detect_lint_command(self, command: str) -> str | None:
        """Classify ``command``: return its lint type, or ``None`` if it is not a lint command."""
        ...

    def mark_success(self, lint_type: str, command: str) -> None:
        """Record that ``command`` (of kind ``lint_type``) finished successfully."""
        ...
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)

# Element type yielded by the generic stream wrapper defined in this module.
_T = TypeVar("_T")
class IdleTimeoutError(Exception):
    """Signals that the SDK response stream stayed silent past the idle limit."""
class ContextPressureError(Exception):
    """Raised when context usage exceeds the restart threshold.

    This exception signals that the agent session should be checkpointed
    and restarted with a fresh context to avoid context exhaustion.

    Attributes:
        session_id: SDK session ID for checkpoint query.
        input_tokens: Current input token count.
        output_tokens: Current output token count.
        cache_read_tokens: Current cache read token count.
        pressure_ratio: Ratio of usage to limit (e.g., 0.92 = 92%).
    """

    def __init__(
        self,
        session_id: str,
        input_tokens: int,
        output_tokens: int,
        cache_read_tokens: int,
        pressure_ratio: float,
    ) -> None:
        self.session_id = session_id
        self.input_tokens = input_tokens
        self.output_tokens = output_tokens
        self.cache_read_tokens = cache_read_tokens
        self.pressure_ratio = pressure_ratio
        super().__init__(
            f"Context pressure {pressure_ratio:.1%} exceeds threshold "
            f"(input={input_tokens}, output={output_tokens}, session={session_id})"
        )

    def __reduce__(self) -> tuple[Any, ...]:
        """Support ``pickle`` / ``copy`` despite the multi-argument constructor.

        The default ``BaseException.__reduce__`` rebuilds the exception as
        ``cls(*self.args)``. ``self.args`` only holds the formatted message,
        so reconstruction would fail with ``TypeError`` (four required
        arguments missing). Returning the real constructor arguments makes the
        round trip work; the instance ``__dict__`` is passed along as state so
        any extra attributes attached to the exception survive as well.
        """
        ctor_args = (
            self.session_id,
            self.input_tokens,
            self.output_tokens,
            self.cache_read_tokens,
            self.pressure_ratio,
        )
        return (type(self), ctor_args, dict(vars(self)))
class IdleTimeoutStream(Generic[_T]):
    """Async-iterator adapter that enforces a maximum idle gap between items.

    Each ``__anext__`` call waits at most ``timeout_seconds`` for the wrapped
    stream and raises IdleTimeoutError when that wait expires. The limit is
    suspended for a given call while ``pending_tool_ids`` is non-empty (a tool
    is still executing), and it is disabled entirely when ``timeout_seconds``
    is ``None``.
    """

    def __init__(
        self,
        stream: AsyncIterator[_T],
        timeout_seconds: float | None,
        pending_tool_ids: set[str],
    ) -> None:
        self._stream: AsyncIterator[_T] = stream
        self._timeout_seconds = timeout_seconds
        # Held by reference: the set is read again on every __anext__ call, so
        # mutations made by whoever owns it are observed here.
        self._pending_tool_ids = pending_tool_ids

    def __aiter__(self) -> Self:
        return self

    async def __anext__(self) -> _T:
        # No idle limit configured: delegate straight to the wrapped stream.
        if self._timeout_seconds is None:
            return await self._stream.__anext__()

        # While any tool is outstanding, wait without a deadline.
        if self._pending_tool_ids:
            wait_limit = None
        else:
            wait_limit = self._timeout_seconds

        try:
            return await asyncio.wait_for(
                self._stream.__anext__(),
                timeout=wait_limit,
            )
        except TimeoutError as exc:
            detail = f"SDK stream idle for {self._timeout_seconds:.0f} seconds"
            raise IdleTimeoutError(detail) from exc
@dataclass
class MessageIterationState:
    """Evolving bookkeeping for one session's SDK message iteration.

    A single instance is mutated in place while SDK messages stream in and
    while idle-timeout retries are handled.
    """

    # SDK session ID; updated when a ResultMessage is received.
    session_id: str | None = None
    # Session ID to use when resuming after an idle timeout.
    pending_session_id: str | None = None
    # Number of tool calls seen in the current turn.
    tool_calls_this_turn: int = 0
    # Number of idle-timeout retries attempted so far.
    idle_retry_count: int = 0
    # IDs of tool invocations still awaiting their results.
    pending_tool_ids: set[str] = field(default_factory=set)
    # tool_use_id -> (lint_type, command) for lint commands awaiting a result.
    pending_lint_commands: dict[str, tuple[str, str]] = field(default_factory=dict)
    # Whether any message has arrived in the current turn.
    first_message_received: bool = False
@dataclass
class MessageIterationResult:
    """Outcome reported after one pass over the message stream."""

    # Whether the iteration completed successfully (required, no default).
    success: bool
    # Updated session ID, if one was received.
    session_id: str | None = None
    # Next query to send (for retries); None when nothing further is needed.
    pending_query: str | None = None
    # Session ID to use for the next query.
    pending_session_id: str | None = None
    # Updated idle retry count.
    idle_retry_count: int = 0
174# Callbacks for SDK message streaming events
175ToolUseCallback = Callable[[str, str, dict[str, Any] | None], None]
176AgentTextCallback = Callable[[str, str], None]
@dataclass
class StreamProcessorConfig:
    """Tunable settings consumed by MessageStreamProcessor."""

    # Maximum context tokens, used as the limit for pressure detection.
    context_limit: int = 200_000
    # Ratio (0.0-1.0) at or above which ContextPressureError is raised.
    context_restart_threshold: float = 0.9
@dataclass
class StreamProcessorCallbacks:
    """Optional hooks fired while the stream is processed.

    Both hooks default to ``None``, in which case the corresponding event is
    simply not reported.
    """

    # Called when a ToolUseBlock is encountered.
    on_tool_use: ToolUseCallback | None = None
    # Called when a TextBlock is encountered.
    on_agent_text: AgentTextCallback | None = None
class MessageStreamProcessor:
    """Processes SDK message streams.

    Handles iteration over SDK streams, tracking tool calls, updating lint cache,
    and detecting context pressure. Uses duck typing for SDK message types to
    avoid SDK imports at runtime: messages and content blocks are dispatched on
    ``type(obj).__name__`` rather than ``isinstance``.

    Usage:
        processor = MessageStreamProcessor(config, callbacks)
        result = await processor.process_stream(
            stream, issue_id, state, lifecycle_ctx, lint_cache, query_start, tracer
        )
    """

    def __init__(
        self,
        config: StreamProcessorConfig | None = None,
        callbacks: StreamProcessorCallbacks | None = None,
    ) -> None:
        """Store configuration and callbacks, falling back to defaults.

        Args:
            config: Context-pressure settings; a default config is used if omitted.
            callbacks: Event hooks; an all-``None`` callback set is used if omitted.
        """
        self.config = config or StreamProcessorConfig()
        self.callbacks = callbacks or StreamProcessorCallbacks()

    async def process_stream(
        self,
        stream: AsyncIterator[Any],
        issue_id: str,
        state: MessageIterationState,
        lifecycle_ctx: LifecycleContext,
        lint_cache: LintCacheProtocol,
        query_start: float,
        tracer: TelemetrySpan | None,
    ) -> MessageIterationResult:
        """Process SDK message stream and update state.

        Updates state.session_id, state.tool_calls_this_turn, state.pending_tool_ids,
        and lint_cache on successful lint commands.

        Args:
            stream: The message stream to process.
            issue_id: Issue ID for logging.
            state: Mutable state for the iteration.
            lifecycle_ctx: Lifecycle context for session state.
            lint_cache: Cache for lint command results.
            query_start: Timestamp when query was sent; compared against
                ``time.time()`` to compute latencies.
            tracer: Optional telemetry span context.

        Returns:
            MessageIterationResult with ``success=True``, the latest session ID
            and ``idle_retry_count`` reset to 0.

        Raises:
            ContextPressureError: If context pressure exceeds threshold.
        """
        async for message in stream:
            if not state.first_message_received:
                state.first_message_received = True
                latency = time.time() - query_start
                logger.debug(
                    "Session %s: first message after %.1fs",
                    issue_id,
                    latency,
                )
            if tracer is not None:
                tracer.log_message(message)

            # Duck typing: dispatch on the class name instead of isinstance so
            # the SDK does not have to be imported here.
            # NOTE(review): ToolResultBlock is only looked for inside
            # AssistantMessage content. If the SDK delivers tool results in a
            # different message type, pending tool IDs are never cleared by
            # this loop -- confirm against the SDK's message model.
            msg_type = type(message).__name__
            if msg_type == "AssistantMessage":
                self._process_assistant_message(message, issue_id, state, lint_cache)
            elif msg_type == "ResultMessage":
                self._process_result_message(message, issue_id, state, lifecycle_ctx)

        # Stream exhausted without an exception: success.
        stream_duration = time.time() - query_start
        logger.debug(
            "Session %s: stream complete after %.1fs, %d tool calls",
            issue_id,
            stream_duration,
            state.tool_calls_this_turn,
        )
        return MessageIterationResult(
            success=True,
            session_id=state.session_id,
            idle_retry_count=0,
        )

    def _process_assistant_message(
        self,
        message: object,
        issue_id: str,
        state: MessageIterationState,
        lint_cache: LintCacheProtocol,
    ) -> None:
        """Process an AssistantMessage, handling text/tool blocks.

        Args:
            message: Object whose ``content`` attribute is an iterable of blocks.
            issue_id: Issue ID forwarded to the callbacks.
            state: Iteration state updated with tool-call bookkeeping.
            lint_cache: Lint cache used to classify and record lint commands.
        """
        # ``content`` may be missing or None; treat both as "no blocks".
        for block in getattr(message, "content", None) or []:
            block_type = type(block).__name__
            if block_type == "TextBlock":
                if self.callbacks.on_agent_text is not None:
                    self.callbacks.on_agent_text(issue_id, getattr(block, "text", ""))
            elif block_type == "ToolUseBlock":
                self._handle_tool_use(block, issue_id, state, lint_cache)
            elif block_type == "ToolResultBlock":
                self._handle_tool_result(block, state, lint_cache)

    def _handle_tool_use(
        self,
        block: object,
        issue_id: str,
        state: MessageIterationState,
        lint_cache: LintCacheProtocol,
    ) -> None:
        """Record a tool invocation and remember lint commands awaiting a result."""
        state.tool_calls_this_turn += 1
        block_id = getattr(block, "id", "")
        # Only track real IDs: results with a falsy tool_use_id are ignored, so
        # an empty ID could never be removed from the pending set again.
        if block_id:
            state.pending_tool_ids.add(block_id)

        name = getattr(block, "name", "")
        block_input = getattr(block, "input", {})
        if self.callbacks.on_tool_use is not None:
            self.callbacks.on_tool_use(issue_id, name, block_input)

        if not (isinstance(name, str) and name.lower() == "bash"):
            return

        # The tool input is allowed to be None (see ToolUseCallback), and the
        # command itself may be absent or not a string; fall back to "".
        command = block_input.get("command", "") if isinstance(block_input, dict) else ""
        if not isinstance(command, str):
            command = ""
        lint_type = lint_cache.detect_lint_command(command)
        if lint_type and block_id:
            state.pending_lint_commands[block_id] = (lint_type, command)

    def _handle_tool_result(
        self,
        block: object,
        state: MessageIterationState,
        lint_cache: LintCacheProtocol,
    ) -> None:
        """Clear the pending tool ID and mark a matching lint command successful."""
        tool_use_id = getattr(block, "tool_use_id", None)
        if not tool_use_id:
            return
        state.pending_tool_ids.discard(tool_use_id)
        if tool_use_id in state.pending_lint_commands:
            lint_type, command = state.pending_lint_commands.pop(tool_use_id)
            # A failed lint run is dropped from the pending map but not cached.
            if not getattr(block, "is_error", False):
                lint_cache.mark_success(lint_type, command)

    @staticmethod
    def _extract_token_counts(usage: object) -> tuple[int, int, int]:
        """Return (input, output, cache_read) token counts from a dict or object.

        Missing or falsy values (including None) are normalised to 0.
        """
        keys = ("input_tokens", "output_tokens", "cache_read_input_tokens")
        if isinstance(usage, dict):
            raw = [usage.get(key, 0) for key in keys]
        else:
            raw = [getattr(usage, key, 0) for key in keys]
        input_tokens, output_tokens, cache_read = (value or 0 for value in raw)
        return input_tokens, output_tokens, cache_read

    def _process_result_message(
        self,
        message: object,
        issue_id: str,
        state: MessageIterationState,
        lifecycle_ctx: LifecycleContext,
    ) -> None:
        """Process a ResultMessage, extracting session ID and usage.

        Copies the session ID and final result onto ``state`` / ``lifecycle_ctx``,
        stores the reported token usage on ``lifecycle_ctx.context_usage`` and
        checks it against the configured restart threshold. When the message has
        no usage data, a warning is logged and tracking is disabled instead.

        Raises:
            ContextPressureError: If context pressure exceeds threshold.
        """
        state.session_id = getattr(message, "session_id", None)
        lifecycle_ctx.session_id = state.session_id
        lifecycle_ctx.final_result = getattr(message, "result", "") or ""

        usage = getattr(message, "usage", None)
        if usage is None:
            logger.warning(
                "Session %s: ResultMessage missing usage field, "
                "context pressure tracking disabled",
                issue_id,
            )
            lifecycle_ctx.context_usage.disable_tracking()
            return

        input_tokens, output_tokens, cache_read = self._extract_token_counts(usage)

        # Update lifecycle context with the usage reported by the SDK.
        lifecycle_ctx.context_usage.input_tokens = input_tokens
        lifecycle_ctx.context_usage.output_tokens = output_tokens
        lifecycle_ctx.context_usage.cache_read_tokens = cache_read

        pressure = lifecycle_ctx.context_usage.pressure_ratio(self.config.context_limit)
        if pressure >= self.config.context_restart_threshold:
            # state.session_id was refreshed at the top of this method.
            raise ContextPressureError(
                session_id=state.session_id or "",
                input_tokens=input_tokens,
                output_tokens=output_tokens,
                cache_read_tokens=cache_read,
                pressure_ratio=pressure,
            )