Skip to content

Commit beea9b7

Browse files
authored
openai: Fix missing parse attribute on StreamWrapper (#82)
* openai: Fix missing parse attribute on StreamWrapper * openai: Remove redundant unit tests for raw responses and add explicit parsing to existing tests * openai: Remove is_raw_response attribute from StreamWrapper
1 parent 9b7281e commit beea9b7

File tree

3 files changed

+34
-12
lines changed

3 files changed

+34
-12
lines changed

instrumentation/elastic-opentelemetry-instrumentation-openai/src/opentelemetry/instrumentation/openai/__init__.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,6 @@ def _chat_completion_wrapper(self, wrapped, instance, args, kwargs):
196196
_record_operation_duration_metric(self.operation_duration_metric, error_attributes, start_time)
197197
raise
198198

199-
is_raw_response = _is_raw_response(result)
200199
if kwargs.get("stream"):
201200
return StreamWrapper(
202201
stream=result,
@@ -208,12 +207,12 @@ def _chat_completion_wrapper(self, wrapped, instance, args, kwargs):
208207
start_time=start_time,
209208
token_usage_metric=self.token_usage_metric,
210209
operation_duration_metric=self.operation_duration_metric,
211-
is_raw_response=is_raw_response,
212210
)
213211

214212
logger.debug(f"openai.resources.chat.completions.Completions.create result: {result}")
215213

216214
# if the caller is using with_raw_response we need to parse the output to get the response class we expect
215+
is_raw_response = _is_raw_response(result)
217216
if is_raw_response:
218217
result = result.parse()
219218
response_attributes = _get_attributes_from_response(
@@ -271,7 +270,6 @@ async def _async_chat_completion_wrapper(self, wrapped, instance, args, kwargs):
271270
_record_operation_duration_metric(self.operation_duration_metric, error_attributes, start_time)
272271
raise
273272

274-
is_raw_response = _is_raw_response(result)
275273
if kwargs.get("stream"):
276274
return StreamWrapper(
277275
stream=result,
@@ -283,12 +281,12 @@ async def _async_chat_completion_wrapper(self, wrapped, instance, args, kwargs):
283281
start_time=start_time,
284282
token_usage_metric=self.token_usage_metric,
285283
operation_duration_metric=self.operation_duration_metric,
286-
is_raw_response=is_raw_response,
287284
)
288285

289286
logger.debug(f"openai.resources.chat.completions.AsyncCompletions.create result: {result}")
290287

291288
# if the caller is using with_raw_response we need to parse the output to get the response class we expect
289+
is_raw_response = _is_raw_response(result)
292290
if is_raw_response:
293291
result = result.parse()
294292
response_attributes = _get_attributes_from_response(

instrumentation/elastic-opentelemetry-instrumentation-openai/src/opentelemetry/instrumentation/openai/wrappers.py

+24-6
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@ def __init__(
4747
start_time: float,
4848
token_usage_metric: Histogram,
4949
operation_duration_metric: Histogram,
50-
is_raw_response: bool,
5150
):
5251
# we need to wrap the original response even in case of raw_responses
5352
super().__init__(stream)
@@ -60,7 +59,6 @@ def __init__(
6059
self.token_usage_metric = token_usage_metric
6160
self.operation_duration_metric = operation_duration_metric
6261
self.start_time = start_time
63-
self.is_raw_response = is_raw_response
6462

6563
self.response_id = None
6664
self.model = None
@@ -125,8 +123,6 @@ def __exit__(self, exc_type, exc_value, traceback):
125123
def __iter__(self):
126124
stream = self.__wrapped__
127125
try:
128-
if self.is_raw_response:
129-
stream = stream.parse()
130126
for chunk in stream:
131127
self.process_chunk(chunk)
132128
yield chunk
@@ -145,12 +141,34 @@ async def __aexit__(self, exc_type, exc_value, traceback):
145141
async def __aiter__(self):
146142
stream = self.__wrapped__
147143
try:
148-
if self.is_raw_response:
149-
stream = stream.parse()
150144
async for chunk in stream:
151145
self.process_chunk(chunk)
152146
yield chunk
153147
except Exception as exc:
154148
self.end(exc)
155149
raise
156150
self.end()
151+
152+
def parse(self):
153+
"""
154+
Handles direct parse() call on the client in order to maintain instrumentation on the parsed iterator.
155+
"""
156+
parsed_iterator = self.__wrapped__.parse()
157+
158+
parsed_wrapper = StreamWrapper(
159+
stream=parsed_iterator,
160+
span=self.span,
161+
span_attributes=self.span_attributes,
162+
capture_message_content=self.capture_message_content,
163+
event_attributes=self.event_attributes,
164+
event_logger=self.event_logger,
165+
start_time=self.start_time,
166+
token_usage_metric=self.token_usage_metric,
167+
operation_duration_metric=self.operation_duration_metric,
168+
)
169+
170+
# Handle original sync/async iterators accordingly
171+
if hasattr(parsed_iterator, "__aiter__"):
172+
return parsed_wrapper.__aiter__()
173+
174+
return parsed_wrapper.__iter__()

instrumentation/elastic-opentelemetry-instrumentation-openai/tests/test_chat_completions.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -1171,10 +1171,13 @@ def test_chat_stream_with_raw_response(default_openai_env, trace_exporter, metri
11711171
}
11721172
]
11731173

1174-
chat_completion = client.chat.completions.with_raw_response.create(
1174+
raw_response = client.chat.completions.with_raw_response.create(
11751175
model=TEST_CHAT_MODEL, messages=messages, stream=True
11761176
)
11771177

1178+
# Explicit parse of the raw response
1179+
chat_completion = raw_response.parse()
1180+
11781181
chunks = [chunk.choices[0].delta.content or "" for chunk in chat_completion if chunk.choices]
11791182
assert "".join(chunks) == "Atlantic Ocean"
11801183

@@ -2226,10 +2229,13 @@ async def test_chat_async_stream_with_raw_response(default_openai_env, trace_exp
22262229
}
22272230
]
22282231

2229-
chat_completion = await client.chat.completions.with_raw_response.create(
2232+
raw_response = await client.chat.completions.with_raw_response.create(
22302233
model=TEST_CHAT_MODEL, messages=messages, stream=True
22312234
)
22322235

2236+
# Explicit parse of the raw response
2237+
chat_completion = raw_response.parse()
2238+
22332239
chunks = [chunk.choices[0].delta.content or "" async for chunk in chat_completion if chunk.choices]
22342240
assert "".join(chunks) == "Atlantic Ocean"
22352241

0 commit comments

Comments (0)