Coverage for src / infra / hooks / locking.py: 53%

36 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2026-01-04 04:43 +0000

1"""Lock enforcement hooks for multi-agent file coordination. 

2 

3Contains hooks for enforcing file locks when writing and for cleanup 

4on agent stop. 

5""" 

6 

7from __future__ import annotations 

8 

9from collections.abc import Awaitable, Callable 

10from pathlib import Path 

11from typing import TYPE_CHECKING, Any 

12 

13if TYPE_CHECKING: 

14 from .dangerous_commands import PreToolUseHook 

15 

16from ..tools.command_runner import run_command 

17from ..tools.env import SCRIPTS_DIR, get_lock_dir 

18from ..tools.locking import get_lock_holder 

19from .file_cache import FILE_PATH_KEYS, FILE_WRITE_TOOLS 

20 

21# Type alias for Stop hooks (using Any to avoid SDK import) 

22StopHook = Callable[ 

23 [Any, str | None, Any], 

24 Awaitable[dict[str, Any]], 

25] 

26 

27 

def make_lock_enforcement_hook(
    agent_id: str, repo_path: str | None = None
) -> PreToolUseHook:
    """Create a PreToolUse hook that enforces lock ownership for file writes.

    Tool calls are allowed through unchanged in three cases:

    - the tool is not in ``FILE_WRITE_TOOLS``;
    - the tool has no entry in ``FILE_PATH_KEYS``;
    - the call carries no file path.

    Every other write is blocked unless ``get_lock_holder`` reports
    ``agent_id`` as the current holder of the target file's lock.

    Args:
        agent_id: The agent ID to check lock ownership against.
        repo_path: The repository root path, used as REPO_NAMESPACE for lock
            key computation. Must match the REPO_NAMESPACE environment variable
            set for the agent's shell scripts.

    Returns:
        An async hook function that blocks file writes unless the agent holds
        the lock. The hook returns ``{}`` to allow the tool call. Otherwise it
        returns a dict with ``"decision": "block"`` and a human-readable
        ``"reason"``.
    """
    # Local import: only needed to build a copy-pasteable shell command below.
    import shlex

    async def enforce_lock_ownership(
        hook_input: Any,  # noqa: ANN401 - SDK type, avoid import
        # NOTE(review): the SDK hook callback appears to pass a tool-use id
        # positionally here; the name is kept for compatibility - confirm.
        stderr: str | None,
        context: Any,  # noqa: ANN401 - SDK type, avoid import
    ) -> dict[str, Any]:
        """PreToolUse hook to block file writes unless this agent holds the lock."""
        tool_name = hook_input["tool_name"]

        # Only file-write tools are subject to lock enforcement.
        if tool_name not in FILE_WRITE_TOOLS:
            return {}

        # Look up which tool_input key carries the target path for this tool.
        path_key = FILE_PATH_KEYS.get(tool_name)
        if not path_key:
            return {}

        file_path = hook_input["tool_input"].get(path_key)
        if not file_path:
            # No path provided, can't check lock - allow (tool will fail anyway)
            return {}

        # Pass repo_path as repo_namespace to match shell script key computation.
        lock_holder = get_lock_holder(file_path, repo_namespace=repo_path)

        if lock_holder is None:
            # Quote the path so the suggested command stays valid for paths
            # containing spaces or shell metacharacters. Plain paths are
            # returned unchanged by shlex.quote.
            suggested = shlex.quote(str(file_path))
            return {
                "decision": "block",
                "reason": f"File {file_path} is not locked. Acquire lock with: lock-try.sh {suggested}",
            }

        if lock_holder != agent_id:
            return {
                "decision": "block",
                "reason": f"File {file_path} is locked by {lock_holder}. Wait or coordinate to acquire the lock.",
            }

        # Agent holds the lock, allow the write.
        return {}

    return enforce_lock_ownership

85 

86 

def make_stop_hook(agent_id: str) -> StopHook:
    """Create a Stop hook that cleans up locks for the given agent.

    The hook runs ``lock-release-all.sh`` from ``SCRIPTS_DIR`` in the current
    working directory. It sets ``LOCK_DIR`` and ``AGENT_ID`` in the script's
    environment.

    Cleanup is best effort. Any exception raised while running the script is
    logged and swallowed, so stopping the agent never fails because of lock
    cleanup.

    Args:
        agent_id: The agent ID to clean up locks for when the agent stops.

    Returns:
        An async hook function suitable for use with ClaudeAgentOptions.hooks["Stop"].
        The hook always returns ``{}``.
    """
    # Local import keeps this block self-contained.
    import logging

    logger = logging.getLogger(__name__)

    async def cleanup_locks_on_stop(
        hook_input: Any,  # noqa: ANN401 - SDK type, avoid import
        stderr: str | None,
        context: Any,  # noqa: ANN401 - SDK type, avoid import
    ) -> dict[str, Any]:
        """Stop hook to release all locks held by this agent."""
        script = SCRIPTS_DIR / "lock-release-all.sh"
        try:
            # NOTE(review): the return value of run_command is not inspected,
            # so a non-zero exit status is only surfaced if run_command raises
            # on it - confirm against run_command's contract.
            run_command(
                [str(script)],
                cwd=Path.cwd(),
                env={
                    "LOCK_DIR": str(get_lock_dir()),
                    "AGENT_ID": agent_id,
                },
            )
        except Exception:
            # Best effort cleanup, orchestrator has fallback. Log instead of
            # silently discarding so failed releases remain diagnosable.
            logger.warning(
                "Failed to release locks for agent %s via %s",
                agent_id,
                script,
                exc_info=True,
            )
        return {}

    return cleanup_locks_on_stop