Coverage for src / infra / clients / braintrust_integration.py: 33%

118 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2026-01-04 04:43 +0000

1""" 

2Braintrust integration for Claude Agent SDK. 

3 

4LLM spans are automatically traced by the braintrust.wrappers.claude_agent_sdk wrapper 

5which is set up in cli.py BEFORE importing claude_agent_sdk. 

6 

7This module provides: 

8- TracedAgentExecution: Context manager for creating parent spans with issue metadata 

9- flush_braintrust: Ensure all traces are sent before process exits 

10 

11Usage: 

12 # In cli.py (BEFORE importing claude_agent_sdk): 

13 from braintrust.wrappers.claude_agent_sdk import setup_claude_agent_sdk 

14 setup_claude_agent_sdk(project="mala") 

15 

16 # Then in agent code: 

17 from .braintrust_integration import TracedAgentExecution, flush_braintrust 

18 

19 with TracedAgentExecution(issue_id, agent_id) as tracer: 

20 tracer.log_input(prompt) 

21 async with ClaudeSDKClient(options=options) as client: 

22 await client.query(prompt) 

23 # LLM calls are auto-traced by the wrapper 

24 async for message in client.receive_response(): 

25 tracer.log_message(message) # For output/tool tracking 

26 tracer.set_success(True) 

27""" 

28 

29from __future__ import annotations 

30 

31import os 

32from typing import TYPE_CHECKING, Any, Self 

33 

34if TYPE_CHECKING: 

35 import types 

36 from collections.abc import Callable 

37 from types import TracebackType 

38 

39 from claude_agent_sdk import ( 

40 AssistantMessage, 

41 ResultMessage, 

42 ) 

43 

# Alias for anything yielded by receive_response(). Only AssistantMessage and
# ResultMessage are acted upon; every other message type is accepted and ignored.
SDKMessage = object

# Optional Braintrust dependency. The names below keep these "disabled"
# defaults unless the package imports cleanly.
BRAINTRUST_AVAILABLE: bool = False
braintrust: types.ModuleType | None = None
start_span: Callable[..., object] | None = None
flush: Callable[[], None] | None = None

try:
    import braintrust as _braintrust
    from braintrust import start_span as _start_span
except ImportError:
    # Braintrust is not importable; leave the defaults in place.
    pass
else:
    braintrust = _braintrust
    start_span = _start_span
    flush = _braintrust.flush
    BRAINTRUST_AVAILABLE = True

64 

65 

def is_braintrust_enabled() -> bool:
    """Report whether Braintrust tracing can be used.

    Returns:
        True only when the braintrust package was imported successfully and
        the BRAINTRUST_API_KEY environment variable is present.
    """
    if not BRAINTRUST_AVAILABLE:
        return False
    return "BRAINTRUST_API_KEY" in os.environ

69 

70 

def flush_braintrust() -> None:
    """Flush pending logs to Braintrust.

    Does nothing when the braintrust package is unavailable. Flushing is
    best-effort: a failure is reported on stderr (matching the other
    "[braintrust]" diagnostics in this module) and never propagated.
    """
    if not BRAINTRUST_AVAILABLE or flush is None:
        return

    try:
        flush()
    except Exception as e:
        # Report instead of silently discarding, so lost traces are visible.
        import sys

        print(f"[braintrust] Failed to flush: {e}", file=sys.stderr)

78 

79 

80class TracedAgentExecution: 

81 """ 

82 Context manager for tracing a single agent execution. 

83 

84 Creates a parent span for the agent with issue-level metadata. 

85 LLM calls within this context are automatically traced by the wrapper. 

86 

87 Captures: 

88 - Initial prompt (input) 

89 - Tool call counts 

90 - Final result (output) 

91 - Success/failure status 

92 """ 

93 

94 def __init__( 

95 self, issue_id: str, agent_id: str, metadata: dict[str, Any] | None = None 

96 ): 

97 self.issue_id = issue_id 

98 self.agent_id = agent_id 

99 self.metadata = metadata or {} 

100 self.span: Any = None # Braintrust span object (dynamic type) 

101 self.input_prompt: str | None = None 

102 self.output_text: str = "" 

103 self.tool_calls: list[dict[str, Any]] = [] 

104 self.success: bool = False 

105 self.error: str | None = None 

106 

107 def __enter__(self) -> Self: 

108 if not is_braintrust_enabled(): 

109 return self 

110 

111 # Start the root span for this agent execution 

112 try: 

113 if start_span is None: 

114 return self 

115 span = start_span( 

116 name=f"agent:{self.issue_id}", 

117 type="task", 

118 metadata={ 

119 "issue_id": self.issue_id, 

120 "agent_id": self.agent_id, 

121 **self.metadata, 

122 }, 

123 ) 

124 self.span = span 

125 span.__enter__() # type: ignore[union-attr] 

126 except Exception as e: 

127 import sys 

128 

129 print(f"[braintrust] Failed to start span: {e}", file=sys.stderr) 

130 self.span = None 

131 return self 

132 

133 def __exit__( 

134 self, 

135 exc_type: type[BaseException] | None, 

136 exc_val: BaseException | None, 

137 exc_tb: TracebackType | None, 

138 ) -> None: 

139 if self.span is None: 

140 return 

141 

142 try: 

143 # Record the final output and metrics 

144 if exc_type is not None: 

145 self.error = str(exc_val) 

146 self.success = False 

147 

148 self.span.log( 

149 input=self.input_prompt, 

150 output=self.output_text, 

151 metadata={ 

152 "success": self.success, 

153 "error": self.error, 

154 "tool_calls_count": len(self.tool_calls), 

155 }, 

156 scores={"success": 1.0 if self.success else 0.0}, 

157 ) 

158 self.span.__exit__(exc_type, exc_val, exc_tb) 

159 # Flush to ensure data is sent before process exits 

160 flush_braintrust() 

161 except Exception as e: 

162 import sys 

163 

164 print(f"[braintrust] Failed to close span: {e}", file=sys.stderr) 

165 # Suppress the exception - tracing is best-effort 

166 

167 def log_input(self, prompt: str) -> None: 

168 """Log the initial user prompt.""" 

169 self.input_prompt = prompt 

170 

171 def log_message(self, message: SDKMessage) -> None: 

172 """Log a message from the Claude Agent SDK (for output/tool tracking).""" 

173 # Import SDK types at runtime (only when actually processing messages) 

174 from claude_agent_sdk import AssistantMessage, ResultMessage 

175 

176 try: 

177 if isinstance(message, AssistantMessage): 

178 self._handle_assistant_message(message) 

179 elif isinstance(message, ResultMessage): 

180 self._handle_result_message(message) 

181 except Exception: 

182 pass # Best-effort tracking 

183 

184 def _handle_assistant_message(self, message: AssistantMessage) -> None: 

185 """Process an assistant message, tracking text and tool calls.""" 

186 # Import SDK types at runtime (only when actually processing messages) 

187 from claude_agent_sdk import TextBlock, ToolResultBlock, ToolUseBlock 

188 

189 for block in message.content: 

190 if isinstance(block, TextBlock): 

191 self.output_text += block.text + "\n" 

192 

193 elif isinstance(block, ToolUseBlock): 

194 # Track tool calls for metadata (LLM spans are auto-traced by wrapper) 

195 tool_use_id = getattr(block, "id", f"tool_{len(self.tool_calls)}") 

196 tool_name = block.name 

197 tool_input = block.input 

198 

199 self.tool_calls.append( 

200 { 

201 "id": tool_use_id, 

202 "name": tool_name, 

203 "input": tool_input, 

204 } 

205 ) 

206 

207 elif isinstance(block, ToolResultBlock): 

208 pass # Tool results tracked in wrapper 

209 

210 def _handle_result_message(self, message: ResultMessage) -> None: 

211 """Process the final result message.""" 

212 self.output_text = message.result or self.output_text 

213 

214 def set_success(self, success: bool) -> None: 

215 """Mark the execution as successful or failed.""" 

216 self.success = success 

217 

218 def set_error(self, error: str) -> None: 

219 """Record an error message.""" 

220 self.error = error 

221 self.success = False 

222 

223 

class BraintrustSpan:
    """Adapter exposing TracedAgentExecution through the TelemetrySpan protocol.

    The span owns one TracedAgentExecution and forwards every call to it
    unchanged; it adds no behaviour of its own.
    """

    def __init__(
        self, issue_id: str, agent_id: str, metadata: dict[str, Any] | None = None
    ):
        """Build the wrapped tracer from the given identifiers and metadata."""
        self._tracer = TracedAgentExecution(
            issue_id=issue_id, agent_id=agent_id, metadata=metadata
        )

    def __enter__(self) -> Self:
        """Enter the wrapped tracer and return this adapter."""
        self._tracer.__enter__()
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Forward the exit (including any exception details) to the tracer."""
        self._tracer.__exit__(exc_type, exc_val, exc_tb)

    def log_input(self, prompt: str) -> None:
        """Forward the initial prompt to the tracer."""
        self._tracer.log_input(prompt)

    def log_message(self, message: object) -> None:
        """Forward an SDK message to the tracer."""
        self._tracer.log_message(message)

    def set_success(self, success: bool) -> None:
        """Forward the success flag to the tracer."""
        self._tracer.set_success(success)

    def set_error(self, error: str) -> None:
        """Forward an error message to the tracer."""
        self._tracer.set_error(error)

264 

265 

class BraintrustProvider:
    """Telemetry provider wrapping Braintrust integration.

    This provider wraps the existing TracedAgentExecution API,
    delegating to braintrust_integration.py for the actual tracing.

    The provider is enabled when:
    - braintrust package is installed
    - BRAINTRUST_API_KEY environment variable is set

    Usage:
        provider = BraintrustProvider()
        if provider.is_enabled():
            with provider.create_span("task-123", {"agent_id": "agent-1"}):
                # Work is traced
                pass
            provider.flush()
    """

    def is_enabled(self) -> bool:
        """Check if Braintrust is available and configured."""
        return is_braintrust_enabled()

    def create_span(
        self, name: str, metadata: dict[str, Any] | None = None
    ) -> BraintrustSpan:
        """Create a span for tracing an agent execution.

        The caller's ``metadata`` dict is never modified.

        Args:
            name: Span name (used as issue_id for TracedAgentExecution)
            metadata: Optional metadata dict. If 'agent_id' key is present,
                it's used for the agent_id parameter; otherwise defaults
                to 'unknown'.

        Returns:
            A BraintrustSpan wrapping TracedAgentExecution
        """
        # Work on a shallow copy: popping 'agent_id' from the caller's dict
        # would silently strip the key from an object the caller still owns.
        span_metadata = dict(metadata) if metadata else {}
        agent_id = span_metadata.pop("agent_id", "unknown")
        return BraintrustSpan(
            issue_id=name,
            agent_id=str(agent_id),
            metadata=span_metadata if span_metadata else None,
        )

    def flush(self) -> None:
        """Flush pending Braintrust logs."""
        flush_braintrust()

313 flush_braintrust()