Skip to content

Commit 5ac6e6a

Browse files
authored
fix(mistralai): instrumentation for version 1.9+ compatibility (#3376)
1 parent a189d2d commit 5ac6e6a

24 files changed

+2671
-1969
lines changed

packages/opentelemetry-instrumentation-mistralai/opentelemetry/instrumentation/mistralai/__init__.py

Lines changed: 54 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -35,31 +35,36 @@
3535
from opentelemetry.trace.status import Status, StatusCode
3636
from wrapt import wrap_function_wrapper
3737

38-
from mistralai.models.chat_completion import (
38+
from mistralai.models import (
3939
ChatCompletionResponse,
40-
ChatCompletionResponseChoice,
41-
ChatMessage,
40+
ChatCompletionChoice,
41+
AssistantMessage,
42+
UserMessage,
43+
SystemMessage,
44+
UsageInfo,
45+
EmbeddingResponse,
4246
)
43-
from mistralai.models.common import UsageInfo
44-
from mistralai.models.embeddings import EmbeddingResponse
4547

4648
logger = logging.getLogger(__name__)
4749

48-
_instruments = ("mistralai >= 0.2.0, < 1",)
50+
_instruments = ("mistralai >= 1.0.0",)
4951

5052
WRAPPED_METHODS = [
5153
{
52-
"method": "chat",
54+
"method": "complete",
55+
"module": "chat",
5356
"span_name": "mistralai.chat",
5457
"streaming": False,
5558
},
5659
{
57-
"method": "chat_stream",
60+
"method": "stream",
61+
"module": "chat",
5862
"span_name": "mistralai.chat",
5963
"streaming": True,
6064
},
6165
{
62-
"method": "embeddings",
66+
"method": "create",
67+
"module": "embeddings",
6368
"span_name": "mistralai.embeddings",
6469
"streaming": False,
6570
},
@@ -92,7 +97,7 @@ def _set_input_attributes(span, llm_request_type, to_wrap, kwargs):
9297
message.role,
9398
)
9499
else:
95-
input = kwargs.get("input")
100+
input = kwargs.get("input") or kwargs.get("inputs")
96101

97102
if isinstance(input, str):
98103
_set_span_attribute(
@@ -101,7 +106,7 @@ def _set_input_attributes(span, llm_request_type, to_wrap, kwargs):
101106
_set_span_attribute(
102107
span, f"{SpanAttributes.LLM_PROMPTS}.0.content", input
103108
)
104-
else:
109+
elif input:
105110
for index, prompt in enumerate(input):
106111
_set_span_attribute(
107112
span,
@@ -205,20 +210,22 @@ def _accumulate_streaming_response(span, event_logger, llm_request_type, respons
205210
for res in response:
206211
yield res
207212

208-
if res.model:
209-
accumulated_response.model = res.model
210-
if res.usage:
211-
accumulated_response.usage = res.usage
213+
# Handle new CompletionEvent structure with .data attribute
214+
chunk_data = res.data if hasattr(res, 'data') else res
215+
if chunk_data.model:
216+
accumulated_response.model = chunk_data.model
217+
if chunk_data.usage:
218+
accumulated_response.usage = chunk_data.usage
212219
# Id is the same for all chunks, so it's safe to overwrite it every time
213-
if res.id:
214-
accumulated_response.id = res.id
220+
if chunk_data.id:
221+
accumulated_response.id = chunk_data.id
215222

216-
for idx, choice in enumerate(res.choices):
223+
for idx, choice in enumerate(chunk_data.choices):
217224
if len(accumulated_response.choices) <= idx:
218225
accumulated_response.choices.append(
219-
ChatCompletionResponseChoice(
226+
ChatCompletionChoice(
220227
index=idx,
221-
message=ChatMessage(role="assistant", content=""),
228+
message=AssistantMessage(role="assistant", content=""),
222229
finish_reason=None,
223230
)
224231
)
@@ -247,20 +254,22 @@ async def _aaccumulate_streaming_response(
247254
async for res in response:
248255
yield res
249256

250-
if res.model:
251-
accumulated_response.model = res.model
252-
if res.usage:
253-
accumulated_response.usage = res.usage
257+
# Handle new CompletionEvent structure with .data attribute
258+
chunk_data = res.data if hasattr(res, 'data') else res
259+
if chunk_data.model:
260+
accumulated_response.model = chunk_data.model
261+
if chunk_data.usage:
262+
accumulated_response.usage = chunk_data.usage
254263
# Id is the same for all chunks, so it's safe to overwrite it every time
255-
if res.id:
256-
accumulated_response.id = res.id
264+
if chunk_data.id:
265+
accumulated_response.id = chunk_data.id
257266

258-
for idx, choice in enumerate(res.choices):
267+
for idx, choice in enumerate(chunk_data.choices):
259268
if len(accumulated_response.choices) <= idx:
260269
accumulated_response.choices.append(
261-
ChatCompletionResponseChoice(
270+
ChatCompletionChoice(
262271
index=idx,
263-
message=ChatMessage(role="assistant", content=""),
272+
message=AssistantMessage(role="assistant", content=""),
264273
finish_reason=None,
265274
)
266275
)
@@ -287,9 +296,9 @@ def wrapper(wrapped, instance, args, kwargs):
287296

288297

289298
def _llm_request_type_by_method(method_name):
290-
if method_name == "chat" or method_name == "chat_stream":
299+
if method_name == "complete" or method_name == "stream":
291300
return LLMRequestTypeValues.CHAT
292-
elif method_name == "embeddings":
301+
elif method_name == "create":
293302
return LLMRequestTypeValues.EMBEDDING
294303
else:
295304
return LLMRequestTypeValues.UNKNOWN
@@ -301,7 +310,7 @@ def _emit_message_events(method_wrapped: str, args, kwargs, event_logger):
301310
if method_wrapped == "mistralai.chat":
302311
messages = args[0] if len(args) > 0 else kwargs.get("messages", [])
303312
for message in messages:
304-
if isinstance(message, ChatMessage):
313+
if isinstance(message, (UserMessage, AssistantMessage, SystemMessage)):
305314
role = message.role
306315
content = message.content
307316
elif isinstance(message, dict):
@@ -313,7 +322,7 @@ def _emit_message_events(method_wrapped: str, args, kwargs, event_logger):
313322

314323
# Handle embedding events
315324
elif method_wrapped == "mistralai.embeddings":
316-
embedding_input = args[0] if len(args) > 0 else kwargs.get("input", [])
325+
embedding_input = args[0] if len(args) > 0 else (kwargs.get("input") or kwargs.get("inputs", []))
317326
if isinstance(embedding_input, str):
318327
emit_event(MessageEvent(content=embedding_input, role="user"), event_logger)
319328
elif isinstance(embedding_input, list):
@@ -452,7 +461,7 @@ async def _awrap(
452461
_handle_input(span, event_logger, args, kwargs, to_wrap)
453462

454463
if to_wrap.get("streaming"):
455-
response = wrapped(*args, **kwargs)
464+
response = await wrapped(*args, **kwargs)
456465
else:
457466
response = await wrapped(*args, **kwargs)
458467

@@ -495,21 +504,23 @@ def _instrument(self, **kwargs):
495504

496505
for wrapped_method in WRAPPED_METHODS:
497506
wrap_method = wrapped_method.get("method")
507+
module_name = wrapped_method.get("module")
508+
# Wrap sync methods on the class
498509
wrap_function_wrapper(
499-
"mistralai.client",
500-
f"MistralClient.{wrap_method}",
510+
f"mistralai.{module_name}",
511+
f"{module_name.capitalize()}.{wrap_method}",
501512
_wrap(tracer, event_logger, wrapped_method),
502513
)
514+
# Wrap async methods on the class
503515
wrap_function_wrapper(
504-
"mistralai.async_client",
505-
f"MistralAsyncClient.{wrap_method}",
516+
f"mistralai.{module_name}",
517+
f"{module_name.capitalize()}.{wrap_method}_async",
506518
_awrap(tracer, event_logger, wrapped_method),
507519
)
508520

509521
def _uninstrument(self, **kwargs):
510522
for wrapped_method in WRAPPED_METHODS:
511-
unwrap("mistralai.client.MistralClient", wrapped_method.get("method"))
512-
unwrap(
513-
"mistralai.async_client.MistralAsyncClient",
514-
wrapped_method.get("method"),
515-
)
523+
wrap_method = wrapped_method.get("method")
524+
module_name = wrapped_method.get("module")
525+
unwrap(f"mistralai.{module_name}.{module_name.capitalize()}", wrap_method)
526+
unwrap(f"mistralai.{module_name}.{module_name.capitalize()}", f"{wrap_method}_async")

0 commit comments

Comments (0)