Coverage for llm_dataset_engine/utils/cost_tracker.py: 100%
67 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-15 18:04 +0200
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-15 18:04 +0200
1"""
2Cost tracking for LLM API calls.
4Provides accurate cost tracking with thread safety and detailed breakdowns.
5"""
7import threading
8from dataclasses import dataclass, field
9from decimal import Decimal
10from typing import Dict, Optional
12from llm_dataset_engine.core.models import CostEstimate
@dataclass
class CostEntry:
    """
    Record of one tracked LLM API call.

    Attributes:
        tokens_in: Number of input tokens for the call.
        tokens_out: Number of output tokens for the call.
        cost: Cost attributed to the call.
        model: Identifier of the model that served the call.
        timestamp: Timestamp of the request, as supplied by the caller.
    """

    # Token usage of the call.
    tokens_in: int
    tokens_out: int

    # Monetary cost, kept as Decimal to avoid binary float rounding.
    cost: Decimal

    # Provenance of the call.
    model: str
    timestamp: float
class CostTracker:
    """
    Accumulates token usage and cost for LLM API calls.

    Running totals, the list of recorded entries and the per-stage cost
    breakdown are all read and written while holding one ``threading.Lock``.
    """

    def __init__(
        self,
        input_cost_per_1k: Optional[Decimal] = None,
        output_cost_per_1k: Optional[Decimal] = None,
    ):
        """
        Create a tracker with the given pricing and empty totals.

        Args:
            input_cost_per_1k: Price per 1,000 input tokens. A missing or
                falsy value is replaced by ``Decimal("0.0")``.
            output_cost_per_1k: Price per 1,000 output tokens. A missing or
                falsy value is replaced by ``Decimal("0.0")``.
        """
        self._lock = threading.Lock()

        # Pricing rates; any falsy rate (None or zero) becomes Decimal("0.0").
        self.input_cost_per_1k = (
            input_cost_per_1k if input_cost_per_1k else Decimal("0.0")
        )
        self.output_cost_per_1k = (
            output_cost_per_1k if output_cost_per_1k else Decimal("0.0")
        )

        # Accumulated state, guarded by self._lock.
        self._entries: list[CostEntry] = []
        self._stage_costs: Dict[str, Decimal] = {}
        self._total_cost = Decimal("0.0")
        self._total_input_tokens = 0
        self._total_output_tokens = 0

    def add(
        self,
        tokens_in: int,
        tokens_out: int,
        model: str,
        timestamp: float,
        stage: Optional[str] = None,
    ) -> Decimal:
        """
        Record one API call and fold it into the running totals.

        Args:
            tokens_in: Input tokens used by the call.
            tokens_out: Output tokens used by the call.
            model: Model identifier.
            timestamp: Timestamp of the request.
            stage: Optional stage name; when truthy, the call's cost is also
                added to that stage's subtotal.

        Returns:
            The cost computed for this call.
        """
        # Pricing and record construction touch no shared totals, so they
        # are done before the lock is taken.
        entry_cost = self.calculate_cost(tokens_in, tokens_out)
        record = CostEntry(
            tokens_in=tokens_in,
            tokens_out=tokens_out,
            cost=entry_cost,
            model=model,
            timestamp=timestamp,
        )

        with self._lock:
            self._entries.append(record)

            self._total_input_tokens += tokens_in
            self._total_output_tokens += tokens_out
            self._total_cost += entry_cost

            if stage:
                stage_so_far = self._stage_costs.get(stage, Decimal("0.0"))
                self._stage_costs[stage] = stage_so_far + entry_cost

        return entry_cost

    def calculate_cost(self, tokens_in: int, tokens_out: int) -> Decimal:
        """
        Price a pair of token counts at the tracker's per-1K rates.

        Args:
            tokens_in: Input token count.
            tokens_out: Output token count.

        Returns:
            Input cost plus output cost.
        """
        # Convert each count to thousands of tokens first, then apply the rate.
        thousands_in = Decimal(tokens_in) / 1000
        input_cost = thousands_in * self.input_cost_per_1k

        thousands_out = Decimal(tokens_out) / 1000
        output_cost = thousands_out * self.output_cost_per_1k

        return input_cost + output_cost

    @property
    def total_cost(self) -> Decimal:
        """Accumulated cost of all recorded calls."""
        with self._lock:
            accumulated = self._total_cost
        return accumulated

    @property
    def total_tokens(self) -> int:
        """Combined input and output tokens of all recorded calls."""
        with self._lock:
            combined = self._total_input_tokens + self._total_output_tokens
        return combined

    @property
    def input_tokens(self) -> int:
        """Accumulated input tokens of all recorded calls."""
        with self._lock:
            consumed = self._total_input_tokens
        return consumed

    @property
    def output_tokens(self) -> int:
        """Accumulated output tokens of all recorded calls."""
        with self._lock:
            produced = self._total_output_tokens
        return produced

    def get_estimate(self, rows: int = 0) -> CostEstimate:
        """
        Package the current totals as a ``CostEstimate``.

        Args:
            rows: Number of rows processed; passed through unchanged.

        Returns:
            CostEstimate built from the current totals and a copy of the
            per-stage breakdown, with confidence set to "actual".
        """
        with self._lock:
            tokens_in = self._total_input_tokens
            tokens_out = self._total_output_tokens
            return CostEstimate(
                total_cost=self._total_cost,
                total_tokens=tokens_in + tokens_out,
                input_tokens=tokens_in,
                output_tokens=tokens_out,
                rows=rows,
                breakdown_by_stage=dict(self._stage_costs),
                confidence="actual",
            )

    def reset(self) -> None:
        """Discard all recorded entries, stage subtotals and totals."""
        with self._lock:
            # Containers are emptied in place rather than rebound.
            self._entries.clear()
            self._stage_costs.clear()
            self._total_cost = Decimal("0.0")
            self._total_input_tokens = 0
            self._total_output_tokens = 0

    def get_stage_costs(self) -> Dict[str, Decimal]:
        """Return a copy of the per-stage cost breakdown."""
        with self._lock:
            snapshot = dict(self._stage_costs)
        return snapshot