Coverage for llm_dataset_engine/utils/cost_tracker.py: 100%

67 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-10-15 18:04 +0200

1""" 

2Cost tracking for LLM API calls. 

3 

4Provides accurate cost tracking with thread safety and detailed breakdowns. 

5""" 

6 

7import threading 

8from dataclasses import dataclass, field 

9from decimal import Decimal 

10from typing import Dict, Optional 

11 

12from llm_dataset_engine.core.models import CostEstimate 

13 

14 

@dataclass
class CostEntry:
    """Record of one tracked LLM call: token counts, cost, model and time."""

    tokens_in: int  # input tokens used by the call
    tokens_out: int  # output tokens used by the call
    cost: Decimal  # cost attributed to this call
    model: str  # model identifier supplied by the caller
    timestamp: float  # caller-supplied timestamp of the request

24 

25 

class CostTracker:
    """
    Accumulate token usage and cost of LLM API calls.

    The running totals, the list of entries and the per-stage breakdown
    are only read or modified while holding an internal ``threading.Lock``.
    """

    def __init__(
        self,
        input_cost_per_1k: Optional[Decimal] = None,
        output_cost_per_1k: Optional[Decimal] = None,
    ):
        """
        Create an empty tracker with the given pricing.

        Args:
            input_cost_per_1k: Price of 1K input tokens; a missing or zero
                value is stored as ``Decimal("0.0")``
            output_cost_per_1k: Price of 1K output tokens; same fallback
        """
        self._lock = threading.Lock()

        # Truthiness test (not ``is None``): a zero rate is also replaced.
        self.input_cost_per_1k = (
            input_cost_per_1k if input_cost_per_1k else Decimal("0.0")
        )
        self.output_cost_per_1k = (
            output_cost_per_1k if output_cost_per_1k else Decimal("0.0")
        )

        self._entries: list[CostEntry] = []
        self._stage_costs: Dict[str, Decimal] = {}
        self._total_cost = Decimal("0.0")
        self._total_input_tokens = 0
        self._total_output_tokens = 0

    def add(
        self,
        tokens_in: int,
        tokens_out: int,
        model: str,
        timestamp: float,
        stage: Optional[str] = None,
    ) -> Decimal:
        """
        Record one API call and fold it into the running totals.

        Args:
            tokens_in: Input tokens used
            tokens_out: Output tokens used
            model: Model identifier
            timestamp: Timestamp of request
            stage: Optional stage name; when truthy, the cost is also
                added to that stage's subtotal

        Returns:
            Cost computed for this call
        """
        call_cost = self.calculate_cost(tokens_in, tokens_out)
        record = CostEntry(
            tokens_in=tokens_in,
            tokens_out=tokens_out,
            cost=call_cost,
            model=model,
            timestamp=timestamp,
        )

        with self._lock:
            self._entries.append(record)

            self._total_input_tokens += tokens_in
            self._total_output_tokens += tokens_out
            self._total_cost += call_cost

            if stage:
                stage_so_far = self._stage_costs.get(stage, Decimal("0.0"))
                self._stage_costs[stage] = stage_so_far + call_cost

        return call_cost

    def calculate_cost(self, tokens_in: int, tokens_out: int) -> Decimal:
        """
        Price a pair of token counts without recording anything.

        Args:
            tokens_in: Input tokens
            tokens_out: Output tokens

        Returns:
            Sum of the input and output cost at the configured per-1K rates
        """
        cost_of_input = (Decimal(tokens_in) / 1000) * self.input_cost_per_1k
        cost_of_output = (Decimal(tokens_out) / 1000) * self.output_cost_per_1k
        return cost_of_input + cost_of_output

    @property
    def total_cost(self) -> Decimal:
        """Cost accumulated so far."""
        with self._lock:
            accumulated = self._total_cost
        return accumulated

    @property
    def total_tokens(self) -> int:
        """Input plus output tokens accumulated so far."""
        with self._lock:
            consumed = self._total_input_tokens
            produced = self._total_output_tokens
        return consumed + produced

    @property
    def input_tokens(self) -> int:
        """Input tokens accumulated so far."""
        with self._lock:
            consumed = self._total_input_tokens
        return consumed

    @property
    def output_tokens(self) -> int:
        """Output tokens accumulated so far."""
        with self._lock:
            produced = self._total_output_tokens
        return produced

    def get_estimate(self, rows: int = 0) -> CostEstimate:
        """
        Package the current totals as a CostEstimate.

        Args:
            rows: Number of rows processed

        Returns:
            CostEstimate built from the tracked totals, with a copy of the
            per-stage breakdown and confidence set to "actual"
        """
        # Take a consistent snapshot under the lock, build the result after.
        with self._lock:
            consumed = self._total_input_tokens
            produced = self._total_output_tokens
            spent = self._total_cost
            per_stage = dict(self._stage_costs)

        return CostEstimate(
            total_cost=spent,
            total_tokens=consumed + produced,
            input_tokens=consumed,
            output_tokens=produced,
            rows=rows,
            breakdown_by_stage=per_stage,
            confidence="actual",
        )

    def reset(self) -> None:
        """Clear every total, the entry list and the stage breakdown."""
        with self._lock:
            # Containers are emptied in place rather than replaced.
            self._entries.clear()
            self._stage_costs.clear()
            self._total_cost = Decimal("0.0")
            self._total_input_tokens = 0
            self._total_output_tokens = 0

    def get_stage_costs(self) -> Dict[str, Decimal]:
        """Return a copy of the per-stage cost breakdown."""
        with self._lock:
            return self._stage_costs.copy()

173