Coverage for src / pipeline / context_pressure_handler.py: 100%
64 statements
« prev ^ index » next coverage.py v7.13.0, created at 2026-01-04 04:43 +0000
« prev ^ index » next coverage.py v7.13.0, created at 2026-01-04 04:43 +0000
"""ContextPressureHandler: Context pressure checkpoint and restart handling.

Extracted from AgentSessionRunner to separate checkpoint/restart logic from
session lifecycle management. This module handles:
- Fetching checkpoints from agents via SDK client
- Building continuation prompts with checkpoint context
- Managing restart loop state (continuation count, prompts)

Design principles:
- Protocol-based SDK client for testability
- Explicit error type for pressure threshold detection
- Timeout handling for checkpoint fetch operations
"""
15from __future__ import annotations
17import asyncio
18import logging
19from dataclasses import dataclass
20from typing import (
21 TYPE_CHECKING,
22 Any,
23 cast,
24)
26from src.domain.prompts import (
27 build_continuation_prompt,
28 extract_checkpoint,
29)
30from src.pipeline.message_stream_processor import ContextPressureError
32if TYPE_CHECKING:
33 from src.core.protocols import (
34 SDKClientFactoryProtocol,
35 )
# Module-scoped logger, named after this module's import path.
logger = logging.getLogger(name=__name__)

# Seconds allowed for one checkpoint fetch; used as the default for
# ContextPressureConfig.checkpoint_timeout_seconds.
DEFAULT_CHECKPOINT_TIMEOUT_SECONDS = 30
@dataclass()
class ContextPressureConfig:
    """Settings that drive checkpoint capture and session restart.

    Attributes:
        checkpoint_request_prompt: Prompt sent to the running session asking it
            to emit a checkpoint. When empty, the handler skips the fetch and
            uses an empty checkpoint.
        continuation_template: Template used to wrap the checkpoint into the
            restart prompt. When empty, the handler falls back to a built-in
            "Continue from checkpoint" prompt.
        checkpoint_timeout_seconds: Upper bound, in seconds, on one checkpoint
            fetch.
    """

    # Prompt asking the agent for a checkpoint block.
    checkpoint_request_prompt: str
    # Template wrapping the checkpoint into a continuation prompt.
    continuation_template: str
    # Per-fetch time budget; defaults to the module-wide constant.
    checkpoint_timeout_seconds: float = DEFAULT_CHECKPOINT_TIMEOUT_SECONDS
@dataclass()
class CheckpointResult:
    """Outcome of asking an agent for its checkpoint.

    Attributes:
        checkpoint: Checkpoint text pulled from the agent's reply. Empty when
            nothing could be obtained.
        timed_out: True when the fetch was abandoned because it exceeded its
            time budget.
    """

    # Extracted checkpoint text; may be "".
    checkpoint: str
    # Whether the fetch hit its timeout.
    timed_out: bool = False
class ContextPressureHandler:
    """Handles context pressure detection and checkpoint/restart logic.

    This handler encapsulates:
    - Fetching checkpoints from agents before restart
    - Building continuation prompts with checkpoint context
    - Managing restart loop state

    The handler keeps no per-call state. The restart count is passed into
    ``handle_pressure_error`` and the incremented value is returned, so the
    caller owns the restart loop.
    """

    def __init__(
        self,
        config: ContextPressureConfig,
        sdk_client_factory: SDKClientFactoryProtocol,
    ) -> None:
        """Initialize the handler.

        Args:
            config: Context pressure configuration.
            sdk_client_factory: Factory used to create a new SDK client for
                each checkpoint request.
        """
        self.config = config
        self.sdk_client_factory = sdk_client_factory

    async def fetch_checkpoint(
        self,
        session_id: str,
        issue_id: str,
        options: object,
        timeout_seconds: float | None = None,
    ) -> CheckpointResult:
        """Fetch a checkpoint from the agent before a context restart.

        Sends ``config.checkpoint_request_prompt`` to the session identified
        by ``session_id`` through a newly created SDK client. It then
        concatenates the ``text`` of every content block in the streamed
        reply and passes the result to ``extract_checkpoint``.

        The fetch is best-effort. An empty prompt, a timeout, or any
        ``Exception`` raised while creating the client or querying it
        produces an empty checkpoint instead of propagating.

        Args:
            session_id: SDK session ID from ContextPressureError.
            issue_id: Issue ID, used only in log messages.
            options: SDK client options passed to the client factory.
            timeout_seconds: Optional override for
                ``config.checkpoint_timeout_seconds``.

        Returns:
            CheckpointResult holding the extracted checkpoint text. Its
            ``timed_out`` flag is True only when the time budget ran out.
        """
        effective_timeout = (
            timeout_seconds
            if timeout_seconds is not None
            else self.config.checkpoint_timeout_seconds
        )

        logger.info(
            "Session %s: requesting checkpoint from session %s...",
            issue_id,
            session_id[:8] if session_id else "unknown",
        )

        checkpoint_prompt = self.config.checkpoint_request_prompt
        if not checkpoint_prompt:
            logger.warning(
                "Session %s: no checkpoint_request prompt configured, using empty checkpoint",
                issue_id,
            )
            return CheckpointResult(checkpoint="", timed_out=False)

        try:
            # Client creation sits inside the try so a factory failure degrades
            # to an empty checkpoint, like any other query failure, instead of
            # aborting the restart.
            client = self.sdk_client_factory.create(options)
            # Collect fragments and join once rather than repeated `str +=`.
            text_parts: list[str] = []
            async with asyncio.timeout(effective_timeout):
                async with client:
                    await client.query(checkpoint_prompt, session_id=session_id)
                    async for message in client.receive_response():
                        text_parts.extend(self._text_blocks(message))
            # Join inside the try so a malformed fragment is still treated as
            # a query failure rather than escaping the best-effort fetch.
            response_text = "".join(text_parts)
        except TimeoutError:
            # Raised by asyncio.timeout once the budget is exhausted; any
            # partial reply is discarded.
            logger.warning(
                "Session %s: checkpoint fetch timed out after %.1fs, using empty checkpoint",
                issue_id,
                effective_timeout,
            )
            return CheckpointResult(checkpoint="", timed_out=True)
        except Exception as e:
            logger.warning(
                "Session %s: checkpoint query failed: %s, using empty checkpoint",
                issue_id,
                e,
            )
            return CheckpointResult(checkpoint="", timed_out=False)

        checkpoint = extract_checkpoint(response_text)
        logger.debug(
            "Session %s: extracted checkpoint (%d chars)",
            issue_id,
            len(checkpoint),
        )
        return CheckpointResult(checkpoint=checkpoint, timed_out=False)

    @staticmethod
    def _text_blocks(message: object) -> list[str]:
        """Return the ``text`` of each content block on ``message`` that has one.

        Messages without a ``content`` attribute, or whose blocks carry no
        ``text`` attribute, contribute nothing.
        """
        content = getattr(message, "content", None)
        if content is None:
            return []
        return [
            text
            for block in cast("list[Any]", content)
            if (text := getattr(block, "text", None)) is not None
        ]

    def build_continuation_prompt(self, checkpoint: str) -> str:
        """Build the continuation prompt with checkpoint context.

        Args:
            checkpoint: Extracted checkpoint text from the previous session.

        Returns:
            The configured continuation template filled with ``checkpoint``.
            When no template is configured, returns a plain
            "Continue from checkpoint" prompt followed by the checkpoint.
        """
        continuation_template = self.config.continuation_template
        if continuation_template:
            return build_continuation_prompt(continuation_template, checkpoint)
        # Fallback: just use checkpoint as prompt
        return f"Continue from checkpoint:\n\n{checkpoint}"

    async def handle_pressure_error(
        self,
        error: ContextPressureError,
        issue_id: str,
        options: object,
        continuation_count: int,
        remaining_time: float,
    ) -> tuple[str, int]:
        """Handle a ContextPressureError by fetching a checkpoint and building a continuation.

        This is the main entry point for handling context pressure. It:
        1. Clamps the checkpoint timeout to the remaining session time
        2. Fetches a checkpoint from the current session (best-effort)
        3. Logs the restart number and pressure ratio
        4. Builds the continuation prompt

        Args:
            error: The ContextPressureError that triggered the restart; its
                ``session_id`` and ``pressure_ratio`` are read.
            issue_id: Issue ID for logging.
            options: SDK client options.
            continuation_count: Restart count before this restart.
            remaining_time: Remaining session time in seconds. The fetch
                timeout is ``max(0, min(remaining_time, config timeout))``.

        Returns:
            Tuple of (continuation_prompt, continuation_count + 1).
        """
        # Never let the checkpoint fetch outlive the session budget; a
        # non-positive remainder yields a zero timeout.
        effective_timeout = max(
            0, min(remaining_time, self.config.checkpoint_timeout_seconds)
        )

        result = await self.fetch_checkpoint(
            session_id=error.session_id,
            issue_id=issue_id,
            options=options,
            timeout_seconds=effective_timeout,
        )

        new_count = continuation_count + 1
        logger.info(
            "Session %s: context restart #%d at %.1f%%",
            issue_id,
            new_count,
            error.pressure_ratio * 100,
        )

        continuation_prompt = self.build_continuation_prompt(result.checkpoint)

        return continuation_prompt, new_count
# Public API of this module. ContextPressureError is imported from
# src.pipeline.message_stream_processor and re-exported here so callers can
# import it together with the handler.
__all__ = [
    "DEFAULT_CHECKPOINT_TIMEOUT_SECONDS",
    "CheckpointResult",
    "ContextPressureConfig",
    "ContextPressureError",
    "ContextPressureHandler",
]