Coverage for src / domain / deadlock.py: 94%

145 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2026-01-04 04:43 +0000

1"""Deadlock detection domain model. 

2 

3Provides WaitForGraph and DeadlockMonitor for detecting cycles in lock 

4acquisition patterns among parallel agents. 

5 

6The WaitForGraph tracks: 

7- Which agents hold which locks (holds: dict[lock_path, agent_id]) 

8- Which agents are waiting for which locks (waits: dict[agent_id, lock_path]) 

9 

10Cycle detection uses DFS from waiting agents to find circular dependencies. 

11""" 

12 

13from __future__ import annotations 

14 

15import asyncio 

16import logging 

17from collections.abc import Awaitable, Callable 

18from dataclasses import dataclass 

19from typing import TYPE_CHECKING 

20 

21from src.core.models import LockEventType 

22 

23if TYPE_CHECKING: 

24 from collections.abc import Sequence 

25 

26 from src.core.models import LockEvent 

27 

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Public names exported by this module.
__all__ = [
    "AgentInfo",
    "DeadlockCallback",
    "DeadlockInfo",
    "DeadlockMonitor",
    "WaitForGraph",
]

# Type alias for the deadlock callback: it receives the DeadlockInfo for a
# detected deadlock and returns either None (sync callback) or an awaitable
# resolving to None (async callback).
DeadlockCallback = Callable[["DeadlockInfo"], Awaitable[None] | None]

40 

41 

@dataclass
class DeadlockInfo:
    """Describes one detected deadlock and the agent chosen to break it.

    Attributes:
        cycle: Agent IDs that form the deadlock cycle.
        victim_id: ID of the agent selected to be killed (the youngest
            agent in the cycle).
        victim_issue_id: Issue the victim was working on, if any.
        blocked_on: Path of the lock the victim was waiting for.
        blocker_id: ID of the agent holding the lock the victim needs.
        blocker_issue_id: Issue the blocker was working on, if any.
    """

    # The cycle itself.
    cycle: list[str]

    # The agent to sacrifice and what it was doing.
    victim_id: str
    victim_issue_id: str | None

    # What the victim was stuck on, and who was in its way.
    blocked_on: str
    blocker_id: str
    blocker_issue_id: str | None

61 

62 

@dataclass
class AgentInfo:
    """Registry record describing a single registered agent.

    Attributes:
        agent_id: Unique identifier for the agent.
        issue_id: Issue the agent is working on, or None.
        start_time: Unix timestamp at which the agent was registered.
    """

    agent_id: str
    issue_id: str | None
    start_time: float

76 

77 

class WaitForGraph:
    """Graph tracking lock holds and waits for cycle detection.

    The graph maintains two mappings:
    - holds: lock_path -> agent_id (who holds each lock)
    - waits: agent_id -> lock_path (what each agent is waiting for)

    Each agent waits on at most one lock and each lock has at most one
    recorded holder, so every agent has at most one outgoing edge
    (agent -> awaited lock -> holder of that lock). Cycle detection simply
    follows those edges from each waiting agent.
    """

    def __init__(self) -> None:
        """Initialize empty graph."""
        self._holds: dict[str, str] = {}  # lock_path -> agent_id
        self._waits: dict[str, str] = {}  # agent_id -> lock_path

    def add_hold(self, agent_id: str, lock_path: str) -> None:
        """Record that an agent holds a lock.

        If another agent is already recorded as the holder, a warning is
        logged and the record is overwritten with the new agent. If the
        agent was waiting for this same lock, its wait entry is cleared.

        Args:
            agent_id: The agent that acquired the lock.
            lock_path: Path to the lock.
        """
        existing_holder = self._holds.get(lock_path)
        if existing_holder is not None and existing_holder != agent_id:
            logger.warning(
                "Invariant: ACQUIRED for lock held by other agent: "
                "lock=%s holder=%s new_agent=%s",
                lock_path,
                existing_holder,
                agent_id,
            )
        self._holds[lock_path] = agent_id
        logger.debug("Lock acquired: agent_id=%s lock_path=%s", agent_id, lock_path)
        # Clear wait if this agent was waiting for this lock
        if self._waits.get(agent_id) == lock_path:
            del self._waits[agent_id]

    def add_wait(self, agent_id: str, lock_path: str) -> None:
        """Record that an agent is waiting for a lock.

        An agent has at most one wait entry; a previous wait on a different
        lock is overwritten (with a warning). Waiting on a lock the agent
        itself holds is logged as an invariant violation but still recorded.

        Args:
            agent_id: The agent that is waiting.
            lock_path: Path to the lock being waited on.
        """
        if self._holds.get(lock_path) == agent_id:
            logger.warning(
                "Invariant: WAITING on lock already held by same agent: "
                "agent=%s lock=%s",
                agent_id,
                lock_path,
            )
        old_wait = self._waits.get(agent_id)
        if old_wait is not None and old_wait != lock_path:
            logger.warning(
                "Wait edge overwritten: agent_id=%s old_lock=%s new_lock=%s",
                agent_id,
                old_wait,
                lock_path,
            )
        self._waits[agent_id] = lock_path
        logger.debug("Wait added: agent_id=%s lock_path=%s", agent_id, lock_path)

    def remove_hold(self, agent_id: str, lock_path: str) -> None:
        """Remove a hold record when a lock is released.

        The record is only removed when ``agent_id`` is the recorded holder;
        otherwise a warning is logged and the graph is left untouched.

        Args:
            agent_id: The agent releasing the lock.
            lock_path: Path to the lock being released.
        """
        current_holder = self._holds.get(lock_path)
        if current_holder != agent_id:
            logger.warning(
                "Invariant: RELEASED for lock not held by agent: "
                "lock=%s holder=%s agent=%s",
                lock_path,
                current_holder,
                agent_id,
            )
            return
        del self._holds[lock_path]
        logger.debug("Lock released: agent_id=%s lock_path=%s", agent_id, lock_path)

    def remove_agent(self, agent_id: str) -> None:
        """Remove all state for an agent.

        Called when an agent exits (success or failure). Drops the agent's
        wait entry (if any) and every hold recorded for it.

        Args:
            agent_id: The agent to remove.
        """
        self._waits.pop(agent_id, None)
        # Collect first: the dict must not change size while being iterated.
        locks_to_remove = [
            lock for lock, holder in self._holds.items() if holder == agent_id
        ]
        for lock in locks_to_remove:
            del self._holds[lock]

    def get_holder(self, lock_path: str) -> str | None:
        """Get the agent holding a lock.

        Args:
            lock_path: Path to the lock.

        Returns:
            Agent ID if the lock is held, None otherwise.
        """
        return self._holds.get(lock_path)

    def get_waited_lock(self, agent_id: str) -> str | None:
        """Get the lock an agent is waiting for.

        Args:
            agent_id: The agent ID.

        Returns:
            Lock path if the agent is waiting, None otherwise.
        """
        return self._waits.get(agent_id)

    def detect_cycle(self) -> list[str] | None:
        """Detect a deadlock cycle in the wait-for graph.

        Walks the single outgoing edge of each waiting agent, sharing one
        ``safe`` set across all starts so that agents already proven
        cycle-free are not walked again.

        Colors:
        - WHITE (not in any set): unvisited
        - GRAY (in the current path): being explored by the current walk
        - BLACK (in safe): fully explored, proven not to lead to a cycle

        Returns:
            List of agent IDs in the cycle if found, None otherwise.
            The list starts at the agent where the walk re-entered its own
            path and follows the wait-for edges from there.
        """
        safe: set[str] = set()  # BLACK: agents proven not to lead to a cycle

        for start_agent in self._waits:
            if start_agent in safe:
                continue
            cycle = self._find_cycle_from(start_agent, safe)
            if cycle:
                return cycle
        return None

    def _find_cycle_from(self, start_agent: str, safe: set[str]) -> list[str] | None:
        """Follow wait-for edges from a single agent to find a cycle.

        Updates ``safe`` in place with every agent proven not to lead to a
        cycle, including the agent at which the walk terminates.

        Args:
            start_agent: Agent to start searching from.
            safe: Set of agents already proven not to lead to a cycle.

        Returns:
            Cycle path if found, None otherwise.
        """
        path: list[str] = []
        path_set: set[str] = set()  # GRAY: agents in current path

        current = start_agent
        while True:
            if current in safe:
                # Reached a node proven safe, so the entire path is safe.
                safe.update(path_set)
                return None

            if current in path_set:
                # Walked back into our own path: everything from the first
                # occurrence of `current` onward is the cycle.
                return path[path.index(current):]

            # Next hop: the holder of the lock this agent is waiting for.
            lock_waiting = self._waits.get(current)
            holder = None if lock_waiting is None else self._holds.get(lock_waiting)
            if holder is None:
                # Agent is not waiting, or the awaited lock is unheld: the
                # chain ends here, so the path AND this terminal agent are
                # all proven cycle-free.
                safe.update(path_set)
                safe.add(current)
                return None

            path.append(current)
            path_set.add(current)
            current = holder

274 

275 

class DeadlockMonitor:
    """Orchestrates deadlock detection and victim selection.

    Maintains a registry of active agents and their metadata, handles
    lock events to update the wait-for graph, and selects victims
    when deadlocks are detected.

    Victim selection picks the youngest agent (highest start_time) in
    the cycle to minimize wasted work.

    The on_deadlock callback is invoked when a deadlock is detected.
    If set, handle_event will call it with the DeadlockInfo. The
    callback may be sync or async; any awaitable it returns is awaited.
    """

    def __init__(self) -> None:
        """Initialize the monitor with empty state."""
        self._graph = WaitForGraph()
        self._agents: dict[str, AgentInfo] = {}
        self.on_deadlock: DeadlockCallback | None = None

    def register_agent(
        self, agent_id: str, issue_id: str | None, start_time: float
    ) -> None:
        """Register an agent with the monitor.

        Registering an already-known agent_id replaces its metadata.

        Args:
            agent_id: Unique identifier for the agent.
            issue_id: Issue the agent is working on (may be None).
            start_time: Unix timestamp when the agent started.
        """
        self._agents[agent_id] = AgentInfo(
            agent_id=agent_id,
            issue_id=issue_id,
            start_time=start_time,
        )
        logger.info("Agent registered: agent_id=%s issue_id=%s", agent_id, issue_id)

    def unregister_agent(self, agent_id: str) -> None:
        """Unregister an agent and clear its state.

        Graph state (holds and waits) is always cleared, even for an agent
        that was never registered.

        Args:
            agent_id: Agent to unregister.
        """
        self._graph.remove_agent(agent_id)
        if self._agents.pop(agent_id, None) is not None:
            logger.info("Agent unregistered: agent_id=%s", agent_id)

    async def handle_event(self, event: LockEvent) -> DeadlockInfo | None:
        """Process a lock event and check for deadlocks.

        Updates the wait-for graph based on the event type, then checks
        for cycles if the event indicates waiting. If a deadlock is detected
        and on_deadlock is set, invokes the callback and awaits its result
        when that result is awaitable.

        Events from unregistered agents are logged as a warning but are
        still applied to the graph.

        Args:
            event: The lock event to process.

        Returns:
            DeadlockInfo if a deadlock is detected, None otherwise.
        """
        # Function-scope import keeps this block self-contained.
        import inspect

        if event.agent_id not in self._agents:
            logger.warning("Event for unregistered agent: agent_id=%s", event.agent_id)

        logger.debug(
            "Event received: type=%s agent_id=%s lock_path=%s",
            event.event_type.value,
            event.agent_id,
            event.lock_path,
        )

        deadlock_info = None
        if event.event_type == LockEventType.ACQUIRED:
            self._graph.add_hold(event.agent_id, event.lock_path)
        elif event.event_type == LockEventType.WAITING:
            self._graph.add_wait(event.agent_id, event.lock_path)
            # Only a new wait edge can close a cycle, so check here.
            deadlock_info = self._check_for_deadlock(event.agent_id, event.lock_path)
            if deadlock_info is not None and self.on_deadlock is not None:
                result = self.on_deadlock(deadlock_info)
                # isawaitable covers coroutines, Futures/Tasks and any object
                # with __await__, matching the declared callback return type.
                if inspect.isawaitable(result):
                    await result
        elif event.event_type == LockEventType.RELEASED:
            self._graph.remove_hold(event.agent_id, event.lock_path)

        logger.debug(
            "Graph updated: holds=%d waits=%d",
            len(self._graph._holds),
            len(self._graph._waits),
        )
        return deadlock_info

    def _check_for_deadlock(
        self, waiting_agent: str, lock_path: str
    ) -> DeadlockInfo | None:
        """Check for deadlock and select victim if found.

        Args:
            waiting_agent: Agent that just started waiting.
            lock_path: Lock the agent is waiting for; used as the fallback
                for ``blocked_on`` when the victim has no recorded wait.

        Returns:
            DeadlockInfo with victim selection if a deadlock is detected,
            None if there is no cycle or no agent in the cycle is registered.
        """
        cycle = self._graph.detect_cycle()
        logger.debug("Cycle check: found=%s", cycle is not None)
        if not cycle:
            return None

        logger.warning("Cycle detected: agents=%s", cycle)

        # Select victim: youngest agent (max start_time) in cycle
        victim = self._select_victim(cycle)
        if victim is None:
            # No registered agents in cycle (shouldn't happen)
            return None

        # Report the victim's own wait, not the lock that triggered the check.
        blocked_on = self._graph.get_waited_lock(victim.agent_id) or lock_path
        blocker_id = self._graph.get_holder(blocked_on)
        blocker_info = self._agents.get(blocker_id) if blocker_id else None

        return DeadlockInfo(
            cycle=cycle,
            victim_id=victim.agent_id,
            victim_issue_id=victim.issue_id,
            blocked_on=blocked_on,
            blocker_id=blocker_id or "",
            blocker_issue_id=blocker_info.issue_id if blocker_info else None,
        )

    def _select_victim(self, cycle: Sequence[str]) -> AgentInfo | None:
        """Select the victim from a deadlock cycle.

        Picks the youngest agent (highest start_time) to minimize wasted work.
        Agent IDs in the cycle that are not registered are ignored.

        Args:
            cycle: List of agent IDs in the deadlock cycle.

        Returns:
            AgentInfo for the selected victim, or None if no registered agents.
        """
        candidates = [self._agents[a] for a in cycle if a in self._agents]
        if not candidates:
            return None
        victim = max(candidates, key=lambda a: a.start_time)
        logger.info(
            "Victim selected: agent_id=%s start_time=%f (youngest in cycle)",
            victim.agent_id,
            victim.start_time,
        )
        return victim