Skip to content

Commit dbc54f8

Browse files
AndyJCaijamesbraza
andauthored
Make the callbacks in cost_tracker async (#335)
Co-authored-by: James Braza <[email protected]>
1 parent af57db5 commit dbc54f8

File tree

2 files changed

+43
-26
lines changed

2 files changed

+43
-26
lines changed

packages/lmi/src/lmi/cost_tracker.py

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from collections.abc import Awaitable, Callable
44
from contextlib import contextmanager
55
from functools import wraps
6-
from typing import Any, ParamSpec, TypeVar
6+
from typing import ParamSpec, TypeVar
77

88
import litellm
99

@@ -20,15 +20,12 @@ def __init__(self):
2020
self.enabled = contextvars.ContextVar[bool]("track_costs", default=False)
2121
# Not a contextvar because I can't imagine a scenario where you'd want more fine-grained control
2222
self.report_every_usd = 1.0
23-
self._callbacks: list[Callable[[LLMResponse], Any]] = []
23+
self._callbacks: list[Callable[[LLMResponse], Awaitable]] = []
2424

25-
def add_callback(self, callback: Callable[[LLMResponse], Any]) -> None:
25+
def add_callback(self, callback: Callable[[LLMResponse], Awaitable]) -> None:
2626
self._callbacks.append(callback)
2727

28-
def record(
29-
self,
30-
response: LLMResponse,
31-
) -> None:
28+
async def record(self, response: LLMResponse) -> None:
3229
self.lifetime_cost_usd += litellm.cost_calculator.completion_cost(
3330
completion_response=response
3431
)
@@ -39,7 +36,7 @@ def record(
3936

4037
for callback in self._callbacks:
4138
try:
42-
callback(response)
39+
await callback(response)
4340
except Exception as e:
4441
logger.warning(
4542
f"Callback failed during cost tracking: {e}", exc_info=True
@@ -100,7 +97,7 @@ async def api_call(...) -> litellm.ModelResponse:
10097
async def wrapped_func(*args, **kwargs):
10198
response = await func(*args, **kwargs)
10299
if GLOBAL_COST_TRACKER.enabled.get():
103-
GLOBAL_COST_TRACKER.record(response)
100+
await GLOBAL_COST_TRACKER.record(response)
104101
return response
105102

106103
return wrapped_func
@@ -140,16 +137,10 @@ def __iter__(self):
140137
def __aiter__(self):
141138
return self
142139

143-
def __next__(self):
144-
response = next(self.stream)
145-
if GLOBAL_COST_TRACKER.enabled.get():
146-
GLOBAL_COST_TRACKER.record(response)
147-
return response
148-
149140
async def __anext__(self):
150141
response = await self.stream.__anext__()
151142
if GLOBAL_COST_TRACKER.enabled.get():
152-
GLOBAL_COST_TRACKER.record(response)
143+
await GLOBAL_COST_TRACKER.record(response)
153144
return response
154145

155146

packages/lmi/tests/test_cost_tracking.py

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
from contextlib import contextmanager
22
from typing import Any
3-
from unittest.mock import MagicMock, patch
3+
from unittest.mock import AsyncMock, MagicMock, patch
44

55
import numpy as np
66
import pytest
77
from aviary.core import Message
88

99
from lmi import cost_tracking_ctx
10-
from lmi.cost_tracker import GLOBAL_COST_TRACKER
10+
from lmi.cost_tracker import GLOBAL_COST_TRACKER, TrackedStreamWrapper
1111
from lmi.embeddings import LiteLLMEmbeddingModel
1212
from lmi.llms import CommonLLMNames, LiteLLMModel
1313
from lmi.utils import VCR_DEFAULT_MATCH_ON
@@ -164,24 +164,25 @@ class TestCostTrackerCallback:
164164
@pytest.mark.asyncio
165165
async def test_callback_succeeds(self):
166166
mock_response = MagicMock(cost=0.01)
167+
callback_calls = []
167168

168-
callback_calls: list[Any] = []
169-
GLOBAL_COST_TRACKER.add_callback(callback_calls.append)
169+
async def async_callback(response): # noqa: RUF029
170+
callback_calls.append(response)
171+
172+
GLOBAL_COST_TRACKER.add_callback(async_callback)
170173

171174
with (
172175
cost_tracking_ctx(),
173176
patch("litellm.cost_calculator.completion_cost", return_value=0.01),
174177
):
175-
GLOBAL_COST_TRACKER.record(mock_response)
178+
await GLOBAL_COST_TRACKER.record(mock_response)
176179

177180
assert len(callback_calls) == 1
178181
assert callback_calls[0] == mock_response
179-
180182
assert GLOBAL_COST_TRACKER.lifetime_cost_usd > 0
181183

182184
@pytest.mark.asyncio
183185
async def test_callback_failure_does_not_break_tracker(self, caplog):
184-
"""Test that a failing callback doesn't break the cost tracker."""
185186
mock_response = MagicMock(cost=0.01)
186187
failing_callback = MagicMock(side_effect=Exception("Callback failed"))
187188
GLOBAL_COST_TRACKER.add_callback(failing_callback)
@@ -190,7 +191,7 @@ async def test_callback_failure_does_not_break_tracker(self, caplog):
190191
cost_tracking_ctx(),
191192
patch("litellm.cost_calculator.completion_cost", return_value=0.01),
192193
):
193-
GLOBAL_COST_TRACKER.record(mock_response)
194+
await GLOBAL_COST_TRACKER.record(mock_response)
194195

195196
failing_callback.assert_called_once_with(mock_response)
196197

@@ -200,7 +201,6 @@ async def test_callback_failure_does_not_break_tracker(self, caplog):
200201

201202
@pytest.mark.asyncio
202203
async def test_multiple_callbacks_with_one_failing(self, caplog):
203-
"""Test that one failing callback doesn't prevent other callbacks from running."""
204204
mock_response = MagicMock(cost=0.01)
205205
failing_callback = MagicMock(side_effect=Exception("Callback failed"))
206206
succeeding_callback = MagicMock()
@@ -212,10 +212,36 @@ async def test_multiple_callbacks_with_one_failing(self, caplog):
212212
cost_tracking_ctx(),
213213
patch("litellm.cost_calculator.completion_cost", return_value=0.01),
214214
):
215-
GLOBAL_COST_TRACKER.record(mock_response)
215+
await GLOBAL_COST_TRACKER.record(mock_response)
216216

217217
failing_callback.assert_called_once_with(mock_response)
218218
succeeding_callback.assert_called_once_with(mock_response)
219219

220220
assert "Callback failed during cost tracking" in caplog.text
221221
assert GLOBAL_COST_TRACKER.lifetime_cost_usd > 0
222+
223+
@pytest.mark.asyncio
224+
async def test_async_context_with_stream_wrapper(self):
225+
mock_stream = MagicMock()
226+
mock_response = MagicMock(cost=0.01)
227+
mock_stream.__anext__ = AsyncMock(return_value=mock_response)
228+
229+
wrapper = TrackedStreamWrapper(mock_stream)
230+
231+
callback_calls = []
232+
233+
async def async_callback(response): # noqa: RUF029
234+
callback_calls.append(response)
235+
236+
GLOBAL_COST_TRACKER.add_callback(async_callback)
237+
238+
with (
239+
cost_tracking_ctx(),
240+
patch("litellm.cost_calculator.completion_cost", return_value=0.01),
241+
):
242+
result = await anext(wrapper)
243+
244+
assert result == mock_response
245+
assert len(callback_calls) == 1
246+
assert callback_calls[0] == mock_response
247+
assert GLOBAL_COST_TRACKER.lifetime_cost_usd > 0

0 commit comments

Comments
 (0)