11from contextlib import contextmanager
22from typing import Any
3- from unittest .mock import MagicMock , patch
3+ from unittest .mock import AsyncMock , MagicMock , patch
44
55import numpy as np
66import pytest
77from aviary .core import Message
88
99from lmi import cost_tracking_ctx
10- from lmi .cost_tracker import GLOBAL_COST_TRACKER
10+ from lmi .cost_tracker import GLOBAL_COST_TRACKER , TrackedStreamWrapper
1111from lmi .embeddings import LiteLLMEmbeddingModel
1212from lmi .llms import CommonLLMNames , LiteLLMModel
1313from lmi .utils import VCR_DEFAULT_MATCH_ON
@@ -164,24 +164,25 @@ class TestCostTrackerCallback:
164164 @pytest .mark .asyncio
165165 async def test_callback_succeeds (self ):
166166 mock_response = MagicMock (cost = 0.01 )
167+ callback_calls = []
167168
168- callback_calls : list [Any ] = []
169- GLOBAL_COST_TRACKER .add_callback (callback_calls .append )
169+ async def async_callback (response ): # noqa: RUF029
170+ callback_calls .append (response )
171+
172+ GLOBAL_COST_TRACKER .add_callback (async_callback )
170173
171174 with (
172175 cost_tracking_ctx (),
173176 patch ("litellm.cost_calculator.completion_cost" , return_value = 0.01 ),
174177 ):
175- GLOBAL_COST_TRACKER .record (mock_response )
178+ await GLOBAL_COST_TRACKER .record (mock_response )
176179
177180 assert len (callback_calls ) == 1
178181 assert callback_calls [0 ] == mock_response
179-
180182 assert GLOBAL_COST_TRACKER .lifetime_cost_usd > 0
181183
182184 @pytest .mark .asyncio
183185 async def test_callback_failure_does_not_break_tracker (self , caplog ):
184- """Test that a failing callback doesn't break the cost tracker."""
185186 mock_response = MagicMock (cost = 0.01 )
186187 failing_callback = MagicMock (side_effect = Exception ("Callback failed" ))
187188 GLOBAL_COST_TRACKER .add_callback (failing_callback )
@@ -190,7 +191,7 @@ async def test_callback_failure_does_not_break_tracker(self, caplog):
190191 cost_tracking_ctx (),
191192 patch ("litellm.cost_calculator.completion_cost" , return_value = 0.01 ),
192193 ):
193- GLOBAL_COST_TRACKER .record (mock_response )
194+ await GLOBAL_COST_TRACKER .record (mock_response )
194195
195196 failing_callback .assert_called_once_with (mock_response )
196197
@@ -200,7 +201,6 @@ async def test_callback_failure_does_not_break_tracker(self, caplog):
200201
201202 @pytest .mark .asyncio
202203 async def test_multiple_callbacks_with_one_failing (self , caplog ):
203- """Test that one failing callback doesn't prevent other callbacks from running."""
204204 mock_response = MagicMock (cost = 0.01 )
205205 failing_callback = MagicMock (side_effect = Exception ("Callback failed" ))
206206 succeeding_callback = MagicMock ()
@@ -212,10 +212,36 @@ async def test_multiple_callbacks_with_one_failing(self, caplog):
212212 cost_tracking_ctx (),
213213 patch ("litellm.cost_calculator.completion_cost" , return_value = 0.01 ),
214214 ):
215- GLOBAL_COST_TRACKER .record (mock_response )
215+ await GLOBAL_COST_TRACKER .record (mock_response )
216216
217217 failing_callback .assert_called_once_with (mock_response )
218218 succeeding_callback .assert_called_once_with (mock_response )
219219
220220 assert "Callback failed during cost tracking" in caplog .text
221221 assert GLOBAL_COST_TRACKER .lifetime_cost_usd > 0
222+
223+ @pytest .mark .asyncio
224+ async def test_async_context_with_stream_wrapper (self ):
225+ mock_stream = MagicMock ()
226+ mock_response = MagicMock (cost = 0.01 )
227+ mock_stream .__anext__ = AsyncMock (return_value = mock_response )
228+
229+ wrapper = TrackedStreamWrapper (mock_stream )
230+
231+ callback_calls = []
232+
233+ async def async_callback (response ): # noqa: RUF029
234+ callback_calls .append (response )
235+
236+ GLOBAL_COST_TRACKER .add_callback (async_callback )
237+
238+ with (
239+ cost_tracking_ctx (),
240+ patch ("litellm.cost_calculator.completion_cost" , return_value = 0.01 ),
241+ ):
242+ result = await anext (wrapper )
243+
244+ assert result == mock_response
245+ assert len (callback_calls ) == 1
246+ assert callback_calls [0 ] == mock_response
247+ assert GLOBAL_COST_TRACKER .lifetime_cost_usd > 0
0 commit comments