35
35
from opentelemetry .trace .status import Status , StatusCode
36
36
from wrapt import wrap_function_wrapper
37
37
38
- from mistralai .models . chat_completion import (
38
+ from mistralai .models import (
39
39
ChatCompletionResponse ,
40
- ChatCompletionResponseChoice ,
41
- ChatMessage ,
40
+ ChatCompletionChoice ,
41
+ AssistantMessage ,
42
+ UserMessage ,
43
+ SystemMessage ,
44
+ UsageInfo ,
45
+ EmbeddingResponse ,
42
46
)
43
- from mistralai .models .common import UsageInfo
44
- from mistralai .models .embeddings import EmbeddingResponse
45
47
46
48
logger = logging .getLogger (__name__ )
47
49
48
- _instruments = ("mistralai >= 0.2.0, < 1 " ,)
50
+ _instruments = ("mistralai >= 1.0.0 " ,)
49
51
50
52
WRAPPED_METHODS = [
51
53
{
52
- "method" : "chat" ,
54
+ "method" : "complete" ,
55
+ "module" : "chat" ,
53
56
"span_name" : "mistralai.chat" ,
54
57
"streaming" : False ,
55
58
},
56
59
{
57
- "method" : "chat_stream" ,
60
+ "method" : "stream" ,
61
+ "module" : "chat" ,
58
62
"span_name" : "mistralai.chat" ,
59
63
"streaming" : True ,
60
64
},
61
65
{
62
- "method" : "embeddings" ,
66
+ "method" : "create" ,
67
+ "module" : "embeddings" ,
63
68
"span_name" : "mistralai.embeddings" ,
64
69
"streaming" : False ,
65
70
},
@@ -92,7 +97,7 @@ def _set_input_attributes(span, llm_request_type, to_wrap, kwargs):
92
97
message .role ,
93
98
)
94
99
else :
95
- input = kwargs .get ("input" )
100
+ input = kwargs .get ("input" ) or kwargs . get ( "inputs" )
96
101
97
102
if isinstance (input , str ):
98
103
_set_span_attribute (
@@ -101,7 +106,7 @@ def _set_input_attributes(span, llm_request_type, to_wrap, kwargs):
101
106
_set_span_attribute (
102
107
span , f"{ SpanAttributes .LLM_PROMPTS } .0.content" , input
103
108
)
104
- else :
109
+ elif input :
105
110
for index , prompt in enumerate (input ):
106
111
_set_span_attribute (
107
112
span ,
@@ -205,20 +210,22 @@ def _accumulate_streaming_response(span, event_logger, llm_request_type, respons
205
210
for res in response :
206
211
yield res
207
212
208
- if res .model :
209
- accumulated_response .model = res .model
210
- if res .usage :
211
- accumulated_response .usage = res .usage
213
+ # Handle new CompletionEvent structure with .data attribute
214
+ chunk_data = res .data if hasattr (res , 'data' ) else res
215
+ if chunk_data .model :
216
+ accumulated_response .model = chunk_data .model
217
+ if chunk_data .usage :
218
+ accumulated_response .usage = chunk_data .usage
212
219
# Id is the same for all chunks, so it's safe to overwrite it every time
213
- if res .id :
214
- accumulated_response .id = res .id
220
+ if chunk_data .id :
221
+ accumulated_response .id = chunk_data .id
215
222
216
- for idx , choice in enumerate (res .choices ):
223
+ for idx , choice in enumerate (chunk_data .choices ):
217
224
if len (accumulated_response .choices ) <= idx :
218
225
accumulated_response .choices .append (
219
- ChatCompletionResponseChoice (
226
+ ChatCompletionChoice (
220
227
index = idx ,
221
- message = ChatMessage (role = "assistant" , content = "" ),
228
+ message = AssistantMessage (role = "assistant" , content = "" ),
222
229
finish_reason = None ,
223
230
)
224
231
)
@@ -247,20 +254,22 @@ async def _aaccumulate_streaming_response(
247
254
async for res in response :
248
255
yield res
249
256
250
- if res .model :
251
- accumulated_response .model = res .model
252
- if res .usage :
253
- accumulated_response .usage = res .usage
257
+ # Handle new CompletionEvent structure with .data attribute
258
+ chunk_data = res .data if hasattr (res , 'data' ) else res
259
+ if chunk_data .model :
260
+ accumulated_response .model = chunk_data .model
261
+ if chunk_data .usage :
262
+ accumulated_response .usage = chunk_data .usage
254
263
# Id is the same for all chunks, so it's safe to overwrite it every time
255
- if res .id :
256
- accumulated_response .id = res .id
264
+ if chunk_data .id :
265
+ accumulated_response .id = chunk_data .id
257
266
258
- for idx , choice in enumerate (res .choices ):
267
+ for idx , choice in enumerate (chunk_data .choices ):
259
268
if len (accumulated_response .choices ) <= idx :
260
269
accumulated_response .choices .append (
261
- ChatCompletionResponseChoice (
270
+ ChatCompletionChoice (
262
271
index = idx ,
263
- message = ChatMessage (role = "assistant" , content = "" ),
272
+ message = AssistantMessage (role = "assistant" , content = "" ),
264
273
finish_reason = None ,
265
274
)
266
275
)
@@ -287,9 +296,9 @@ def wrapper(wrapped, instance, args, kwargs):
287
296
288
297
289
298
def _llm_request_type_by_method (method_name ):
290
- if method_name == "chat " or method_name == "chat_stream " :
299
+ if method_name == "complete " or method_name == "stream " :
291
300
return LLMRequestTypeValues .CHAT
292
- elif method_name == "embeddings " :
301
+ elif method_name == "create " :
293
302
return LLMRequestTypeValues .EMBEDDING
294
303
else :
295
304
return LLMRequestTypeValues .UNKNOWN
@@ -301,7 +310,7 @@ def _emit_message_events(method_wrapped: str, args, kwargs, event_logger):
301
310
if method_wrapped == "mistralai.chat" :
302
311
messages = args [0 ] if len (args ) > 0 else kwargs .get ("messages" , [])
303
312
for message in messages :
304
- if isinstance (message , ChatMessage ):
313
+ if isinstance (message , ( UserMessage , AssistantMessage , SystemMessage ) ):
305
314
role = message .role
306
315
content = message .content
307
316
elif isinstance (message , dict ):
@@ -313,7 +322,7 @@ def _emit_message_events(method_wrapped: str, args, kwargs, event_logger):
313
322
314
323
# Handle embedding events
315
324
elif method_wrapped == "mistralai.embeddings" :
316
- embedding_input = args [0 ] if len (args ) > 0 else kwargs .get ("input" , [])
325
+ embedding_input = args [0 ] if len (args ) > 0 else ( kwargs .get ("input" ) or kwargs . get ( "inputs" , []) )
317
326
if isinstance (embedding_input , str ):
318
327
emit_event (MessageEvent (content = embedding_input , role = "user" ), event_logger )
319
328
elif isinstance (embedding_input , list ):
@@ -452,7 +461,7 @@ async def _awrap(
452
461
_handle_input (span , event_logger , args , kwargs , to_wrap )
453
462
454
463
if to_wrap .get ("streaming" ):
455
- response = wrapped (* args , ** kwargs )
464
+ response = await wrapped (* args , ** kwargs )
456
465
else :
457
466
response = await wrapped (* args , ** kwargs )
458
467
@@ -495,21 +504,23 @@ def _instrument(self, **kwargs):
495
504
496
505
for wrapped_method in WRAPPED_METHODS :
497
506
wrap_method = wrapped_method .get ("method" )
507
+ module_name = wrapped_method .get ("module" )
508
+ # Wrap sync methods on the class
498
509
wrap_function_wrapper (
499
- "mistralai.client " ,
500
- f"MistralClient .{ wrap_method } " ,
510
+ f "mistralai.{ module_name } " ,
511
+ f"{ module_name . capitalize () } .{ wrap_method } " ,
501
512
_wrap (tracer , event_logger , wrapped_method ),
502
513
)
514
+ # Wrap async methods on the class
503
515
wrap_function_wrapper (
504
- "mistralai.async_client " ,
505
- f"MistralAsyncClient. { wrap_method } " ,
516
+ f "mistralai.{ module_name } " ,
517
+ f"{ module_name . capitalize () } . { wrap_method } _async " ,
506
518
_awrap (tracer , event_logger , wrapped_method ),
507
519
)
508
520
509
521
def _uninstrument (self , ** kwargs ):
510
522
for wrapped_method in WRAPPED_METHODS :
511
- unwrap ("mistralai.client.MistralClient" , wrapped_method .get ("method" ))
512
- unwrap (
513
- "mistralai.async_client.MistralAsyncClient" ,
514
- wrapped_method .get ("method" ),
515
- )
523
+ wrap_method = wrapped_method .get ("method" )
524
+ module_name = wrapped_method .get ("module" )
525
+ unwrap (f"mistralai.{ module_name } .{ module_name .capitalize ()} " , wrap_method )
526
+ unwrap (f"mistralai.{ module_name } .{ module_name .capitalize ()} " , f"{ wrap_method } _async" )
0 commit comments