11
11
emit_response_events ,
12
12
)
13
13
from opentelemetry .instrumentation .cohere .span_utils import (
14
- set_input_attributes ,
15
- set_response_attributes ,
14
+ set_input_content_attributes ,
15
+ set_response_content_attributes ,
16
16
set_span_request_attributes ,
17
+ set_span_response_attributes ,
18
+ )
19
+ from opentelemetry .instrumentation .cohere .streaming import (
20
+ process_chat_v1_streaming_response ,
21
+ aprocess_chat_v1_streaming_response ,
22
+ process_chat_v2_streaming_response ,
23
+ aprocess_chat_v2_streaming_response ,
17
24
)
18
25
from opentelemetry .instrumentation .cohere .utils import dont_throw , should_emit_events
19
26
from opentelemetry .instrumentation .cohere .version import __version__
27
34
LLMRequestTypeValues ,
28
35
SpanAttributes ,
29
36
)
30
- from opentelemetry .trace import SpanKind , Tracer , get_tracer
37
+ from opentelemetry .trace import SpanKind , Status , StatusCode , Tracer , get_tracer , use_span
31
38
from wrapt import wrap_function_wrapper
32
39
33
40
logger = logging .getLogger (__name__ )
36
43
37
44
WRAPPED_METHODS = [
38
45
{
46
+ "module" : "cohere.client" ,
39
47
"object" : "Client" ,
40
48
"method" : "generate" ,
41
49
"span_name" : "cohere.completion" ,
42
50
},
43
51
{
52
+ "module" : "cohere.client" ,
44
53
"object" : "Client" ,
45
54
"method" : "chat" ,
46
55
"span_name" : "cohere.chat" ,
47
56
},
48
57
{
58
+ "module" : "cohere.client" ,
59
+ "object" : "Client" ,
60
+ "method" : "chat_stream" ,
61
+ "span_name" : "cohere.chat" ,
62
+ "stream_process_func" : process_chat_v1_streaming_response ,
63
+ },
64
+ {
65
+ "module" : "cohere.client" ,
66
+ "object" : "Client" ,
67
+ "method" : "rerank" ,
68
+ "span_name" : "cohere.rerank" ,
69
+ },
70
+ {
71
+ "module" : "cohere.client" ,
49
72
"object" : "Client" ,
73
+ "method" : "embed" ,
74
+ "span_name" : "cohere.embed" ,
75
+ },
76
+ {
77
+ "module" : "cohere.client_v2" ,
78
+ "object" : "ClientV2" ,
79
+ "method" : "chat" ,
80
+ "span_name" : "cohere.chat" ,
81
+ },
82
+ {
83
+ "module" : "cohere.client_v2" ,
84
+ "object" : "ClientV2" ,
85
+ "method" : "chat_stream" ,
86
+ "span_name" : "cohere.chat" ,
87
+ "stream_process_func" : process_chat_v2_streaming_response ,
88
+ },
89
+ {
90
+ "module" : "cohere.client_v2" ,
91
+ "object" : "ClientV2" ,
50
92
"method" : "rerank" ,
51
93
"span_name" : "cohere.rerank" ,
52
94
},
95
+ {
96
+ "module" : "cohere.client_v2" ,
97
+ "object" : "ClientV2" ,
98
+ "method" : "embed" ,
99
+ "span_name" : "cohere.embed" ,
100
+ },
101
+ # Async methods that return AsyncIterator must be wrapped with sync wrapper
102
+ {
103
+ "module" : "cohere.client" ,
104
+ "object" : "AsyncClient" ,
105
+ "method" : "chat_stream" ,
106
+ "span_name" : "cohere.chat" ,
107
+ "stream_process_func" : aprocess_chat_v1_streaming_response ,
108
+ },
109
+ {
110
+ "module" : "cohere.client_v2" ,
111
+ "object" : "AsyncClientV2" ,
112
+ "method" : "chat_stream" ,
113
+ "span_name" : "cohere.chat" ,
114
+ "stream_process_func" : aprocess_chat_v2_streaming_response ,
115
+ },
116
+ ]
117
+
118
+ WRAPPED_AMETHODS = [
119
+ {
120
+ "module" : "cohere.client" ,
121
+ "object" : "AsyncClient" ,
122
+ "method" : "generate" ,
123
+ "span_name" : "cohere.completion" ,
124
+ },
125
+ {
126
+ "module" : "cohere.client" ,
127
+ "object" : "AsyncClient" ,
128
+ "method" : "chat" ,
129
+ "span_name" : "cohere.chat" ,
130
+ },
131
+ {
132
+ "module" : "cohere.client" ,
133
+ "object" : "AsyncClient" ,
134
+ "method" : "rerank" ,
135
+ "span_name" : "cohere.rerank" ,
136
+ },
137
+ {
138
+ "module" : "cohere.client" ,
139
+ "object" : "AsyncClient" ,
140
+ "method" : "embed" ,
141
+ "span_name" : "cohere.embed" ,
142
+ },
143
+ {
144
+ "module" : "cohere.client_v2" ,
145
+ "object" : "AsyncClientV2" ,
146
+ "method" : "chat" ,
147
+ "span_name" : "cohere.chat" ,
148
+ },
149
+ {
150
+ "module" : "cohere.client_v2" ,
151
+ "object" : "AsyncClientV2" ,
152
+ "method" : "rerank" ,
153
+ "span_name" : "cohere.rerank" ,
154
+ },
155
+ {
156
+ "module" : "cohere.client_v2" ,
157
+ "object" : "AsyncClientV2" ,
158
+ "method" : "embed" ,
159
+ "span_name" : "cohere.embed" ,
160
+ },
53
161
]
54
162
55
163
@@ -66,30 +174,30 @@ def wrapper(wrapped, instance, args, kwargs):
66
174
67
175
68
176
def _llm_request_type_by_method (method_name ):
69
- if method_name == "chat" :
177
+ if method_name in [ "chat" , "chat_stream" ] :
70
178
return LLMRequestTypeValues .CHAT
71
- elif method_name == "generate" :
179
+ elif method_name in [ "generate" , "generate_stream" ] :
72
180
return LLMRequestTypeValues .COMPLETION
73
181
elif method_name == "rerank" :
74
182
return LLMRequestTypeValues .RERANK
183
+ elif method_name == "embed" :
184
+ return LLMRequestTypeValues .EMBEDDING
75
185
else :
76
186
return LLMRequestTypeValues .UNKNOWN
77
187
78
188
79
189
@dont_throw
80
- def _handle_input (span , event_logger , llm_request_type , kwargs ):
190
+ def _handle_input_content (span , event_logger , llm_request_type , kwargs ):
191
+ set_input_content_attributes (span , llm_request_type , kwargs )
81
192
if should_emit_events ():
82
193
emit_input_event (event_logger , llm_request_type , kwargs )
83
- else :
84
- set_input_attributes (span , llm_request_type , kwargs )
85
194
86
195
87
196
@dont_throw
88
- def _handle_response (span , event_logger , llm_request_type , response ):
197
+ def _handle_response_content (span , event_logger , llm_request_type , response ):
198
+ set_response_content_attributes (span , llm_request_type , response )
89
199
if should_emit_events ():
90
200
emit_response_events (event_logger , llm_request_type , response )
91
- else :
92
- set_response_attributes (span , llm_request_type , response )
93
201
94
202
95
203
@_with_tracer_wrapper
@@ -108,6 +216,55 @@ def _wrap(
108
216
):
109
217
return wrapped (* args , ** kwargs )
110
218
219
+ name = to_wrap .get ("span_name" )
220
+ llm_request_type = _llm_request_type_by_method (to_wrap .get ("method" ))
221
+ span = tracer .start_span (
222
+ name ,
223
+ kind = SpanKind .CLIENT ,
224
+ attributes = {
225
+ SpanAttributes .LLM_SYSTEM : "Cohere" ,
226
+ SpanAttributes .LLM_REQUEST_TYPE : llm_request_type .value ,
227
+ },
228
+ )
229
+
230
+ with use_span (span , end_on_exit = False ):
231
+ set_span_request_attributes (span , kwargs )
232
+ _handle_input_content (span , event_logger , llm_request_type , kwargs )
233
+
234
+ try :
235
+ response = wrapped (* args , ** kwargs )
236
+ except Exception as e :
237
+ if span .is_recording ():
238
+ span .set_status (Status (StatusCode .ERROR , str (e )))
239
+ span .record_exception (e )
240
+ span .end ()
241
+ raise
242
+
243
+ if to_wrap .get ("stream_process_func" ):
244
+ return to_wrap .get ("stream_process_func" )(span , event_logger , llm_request_type , response )
245
+
246
+ set_span_response_attributes (span , response )
247
+ _handle_response_content (span , event_logger , llm_request_type , response )
248
+ span .end ()
249
+ return response
250
+
251
+
252
+ @_with_tracer_wrapper
253
+ async def _awrap (
254
+ tracer : Tracer ,
255
+ event_logger : Union [EventLogger , None ],
256
+ to_wrap ,
257
+ wrapped ,
258
+ instance ,
259
+ args ,
260
+ kwargs ,
261
+ ):
262
+ """Instruments and calls every function defined in TO_WRAP."""
263
+ if context_api .get_value (_SUPPRESS_INSTRUMENTATION_KEY ) or context_api .get_value (
264
+ SUPPRESS_LANGUAGE_MODEL_INSTRUMENTATION_KEY
265
+ ):
266
+ return await wrapped (* args , ** kwargs )
267
+
111
268
name = to_wrap .get ("span_name" )
112
269
llm_request_type = _llm_request_type_by_method (to_wrap .get ("method" ))
113
270
with tracer .start_as_current_span (
@@ -119,12 +276,19 @@ def _wrap(
119
276
},
120
277
) as span :
121
278
set_span_request_attributes (span , kwargs )
122
- _handle_input (span , event_logger , llm_request_type , kwargs )
279
+ _handle_input_content (span , event_logger , llm_request_type , kwargs )
123
280
124
- response = wrapped (* args , ** kwargs )
281
+ try :
282
+ response = await wrapped (* args , ** kwargs )
283
+ except Exception as e :
284
+ if span .is_recording ():
285
+ span .set_status (Status (StatusCode .ERROR , str (e )))
286
+ span .record_exception (e )
287
+ span .end ()
288
+ raise
125
289
126
- if response :
127
- _handle_response (span , event_logger , llm_request_type , response )
290
+ set_span_response_attributes ( span , response )
291
+ _handle_response_content (span , event_logger , llm_request_type , response )
128
292
129
293
return response
130
294
@@ -151,18 +315,51 @@ def _instrument(self, **kwargs):
151
315
__name__ , __version__ , event_logger_provider = event_logger_provider
152
316
)
153
317
for wrapped_method in WRAPPED_METHODS :
318
+ wrap_module = wrapped_method .get ("module" )
154
319
wrap_object = wrapped_method .get ("object" )
155
320
wrap_method = wrapped_method .get ("method" )
156
- wrap_function_wrapper (
157
- "cohere.client" ,
158
- f"{ wrap_object } .{ wrap_method } " ,
159
- _wrap (tracer , event_logger , wrapped_method ),
160
- )
321
+ try :
322
+ wrap_function_wrapper (
323
+ wrap_module ,
324
+ f"{ wrap_object } .{ wrap_method } " ,
325
+ _wrap (tracer , event_logger , wrapped_method ),
326
+ )
327
+ except (ImportError , ModuleNotFoundError , AttributeError ):
328
+ logger .debug (f"Failed to instrument { wrap_module } .{ wrap_object } .{ wrap_method } " )
329
+
330
+ for wrapped_method in WRAPPED_AMETHODS :
331
+ wrap_module = wrapped_method .get ("module" )
332
+ wrap_object = wrapped_method .get ("object" )
333
+ wrap_method = wrapped_method .get ("method" )
334
+ try :
335
+ wrap_function_wrapper (
336
+ wrap_module ,
337
+ f"{ wrap_object } .{ wrap_method } " ,
338
+ _awrap (tracer , event_logger , wrapped_method ),
339
+ )
340
+ except (ImportError , ModuleNotFoundError , AttributeError ):
341
+ logger .debug (f"Failed to instrument { wrap_module } .{ wrap_object } .{ wrap_method } " )
161
342
162
343
def _uninstrument (self , ** kwargs ):
163
344
for wrapped_method in WRAPPED_METHODS :
345
+ wrap_module = wrapped_method .get ("module" )
164
346
wrap_object = wrapped_method .get ("object" )
165
- unwrap (
166
- f"cohere.client.{ wrap_object } " ,
167
- wrapped_method .get ("method" ),
168
- )
347
+ wrap_method = wrapped_method .get ("method" )
348
+ try :
349
+ unwrap (
350
+ f"{ wrap_module } .{ wrap_object } " ,
351
+ wrap_method ,
352
+ )
353
+ except (ImportError , ModuleNotFoundError , AttributeError ):
354
+ logger .debug (f"Failed to uninstrument { wrap_module } .{ wrap_object } .{ wrap_method } " )
355
+ for wrapped_method in WRAPPED_AMETHODS :
356
+ wrap_module = wrapped_method .get ("module" )
357
+ wrap_object = wrapped_method .get ("object" )
358
+ wrap_method = wrapped_method .get ("method" )
359
+ try :
360
+ unwrap (
361
+ f"{ wrap_module } .{ wrap_object } " ,
362
+ wrap_method ,
363
+ )
364
+ except (ImportError , ModuleNotFoundError , AttributeError ):
365
+ logger .debug (f"Failed to uninstrument { wrap_module } .{ wrap_object } .{ wrap_method } " )
0 commit comments