Skip to content

Commit 7b994c0

Browse files
authored
fix(cohere): add v2 api instrumentation (#3378)
1 parent 5ac6e6a commit 7b994c0

26 files changed

+5267
-176
lines changed

packages/opentelemetry-instrumentation-cohere/opentelemetry/instrumentation/cohere/__init__.py

Lines changed: 221 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,16 @@
1111
emit_response_events,
1212
)
1313
from opentelemetry.instrumentation.cohere.span_utils import (
14-
set_input_attributes,
15-
set_response_attributes,
14+
set_input_content_attributes,
15+
set_response_content_attributes,
1616
set_span_request_attributes,
17+
set_span_response_attributes,
18+
)
19+
from opentelemetry.instrumentation.cohere.streaming import (
20+
process_chat_v1_streaming_response,
21+
aprocess_chat_v1_streaming_response,
22+
process_chat_v2_streaming_response,
23+
aprocess_chat_v2_streaming_response,
1724
)
1825
from opentelemetry.instrumentation.cohere.utils import dont_throw, should_emit_events
1926
from opentelemetry.instrumentation.cohere.version import __version__
@@ -27,7 +34,7 @@
2734
LLMRequestTypeValues,
2835
SpanAttributes,
2936
)
30-
from opentelemetry.trace import SpanKind, Tracer, get_tracer
37+
from opentelemetry.trace import SpanKind, Status, StatusCode, Tracer, get_tracer, use_span
3138
from wrapt import wrap_function_wrapper
3239

3340
logger = logging.getLogger(__name__)
@@ -36,20 +43,121 @@
3643

3744
WRAPPED_METHODS = [
3845
{
46+
"module": "cohere.client",
3947
"object": "Client",
4048
"method": "generate",
4149
"span_name": "cohere.completion",
4250
},
4351
{
52+
"module": "cohere.client",
4453
"object": "Client",
4554
"method": "chat",
4655
"span_name": "cohere.chat",
4756
},
4857
{
58+
"module": "cohere.client",
59+
"object": "Client",
60+
"method": "chat_stream",
61+
"span_name": "cohere.chat",
62+
"stream_process_func": process_chat_v1_streaming_response,
63+
},
64+
{
65+
"module": "cohere.client",
66+
"object": "Client",
67+
"method": "rerank",
68+
"span_name": "cohere.rerank",
69+
},
70+
{
71+
"module": "cohere.client",
4972
"object": "Client",
73+
"method": "embed",
74+
"span_name": "cohere.embed",
75+
},
76+
{
77+
"module": "cohere.client_v2",
78+
"object": "ClientV2",
79+
"method": "chat",
80+
"span_name": "cohere.chat",
81+
},
82+
{
83+
"module": "cohere.client_v2",
84+
"object": "ClientV2",
85+
"method": "chat_stream",
86+
"span_name": "cohere.chat",
87+
"stream_process_func": process_chat_v2_streaming_response,
88+
},
89+
{
90+
"module": "cohere.client_v2",
91+
"object": "ClientV2",
5092
"method": "rerank",
5193
"span_name": "cohere.rerank",
5294
},
95+
{
96+
"module": "cohere.client_v2",
97+
"object": "ClientV2",
98+
"method": "embed",
99+
"span_name": "cohere.embed",
100+
},
101+
# Async methods that return AsyncIterator must be wrapped with sync wrapper
102+
{
103+
"module": "cohere.client",
104+
"object": "AsyncClient",
105+
"method": "chat_stream",
106+
"span_name": "cohere.chat",
107+
"stream_process_func": aprocess_chat_v1_streaming_response,
108+
},
109+
{
110+
"module": "cohere.client_v2",
111+
"object": "AsyncClientV2",
112+
"method": "chat_stream",
113+
"span_name": "cohere.chat",
114+
"stream_process_func": aprocess_chat_v2_streaming_response,
115+
},
116+
]
117+
118+
WRAPPED_AMETHODS = [
119+
{
120+
"module": "cohere.client",
121+
"object": "AsyncClient",
122+
"method": "generate",
123+
"span_name": "cohere.completion",
124+
},
125+
{
126+
"module": "cohere.client",
127+
"object": "AsyncClient",
128+
"method": "chat",
129+
"span_name": "cohere.chat",
130+
},
131+
{
132+
"module": "cohere.client",
133+
"object": "AsyncClient",
134+
"method": "rerank",
135+
"span_name": "cohere.rerank",
136+
},
137+
{
138+
"module": "cohere.client",
139+
"object": "AsyncClient",
140+
"method": "embed",
141+
"span_name": "cohere.embed",
142+
},
143+
{
144+
"module": "cohere.client_v2",
145+
"object": "AsyncClientV2",
146+
"method": "chat",
147+
"span_name": "cohere.chat",
148+
},
149+
{
150+
"module": "cohere.client_v2",
151+
"object": "AsyncClientV2",
152+
"method": "rerank",
153+
"span_name": "cohere.rerank",
154+
},
155+
{
156+
"module": "cohere.client_v2",
157+
"object": "AsyncClientV2",
158+
"method": "embed",
159+
"span_name": "cohere.embed",
160+
},
53161
]
54162

55163

@@ -66,30 +174,30 @@ def wrapper(wrapped, instance, args, kwargs):
66174

67175

68176
def _llm_request_type_by_method(method_name):
69-
if method_name == "chat":
177+
if method_name in ["chat", "chat_stream"]:
70178
return LLMRequestTypeValues.CHAT
71-
elif method_name == "generate":
179+
elif method_name in ["generate", "generate_stream"]:
72180
return LLMRequestTypeValues.COMPLETION
73181
elif method_name == "rerank":
74182
return LLMRequestTypeValues.RERANK
183+
elif method_name == "embed":
184+
return LLMRequestTypeValues.EMBEDDING
75185
else:
76186
return LLMRequestTypeValues.UNKNOWN
77187

78188

79189
@dont_throw
80-
def _handle_input(span, event_logger, llm_request_type, kwargs):
190+
def _handle_input_content(span, event_logger, llm_request_type, kwargs):
191+
set_input_content_attributes(span, llm_request_type, kwargs)
81192
if should_emit_events():
82193
emit_input_event(event_logger, llm_request_type, kwargs)
83-
else:
84-
set_input_attributes(span, llm_request_type, kwargs)
85194

86195

87196
@dont_throw
88-
def _handle_response(span, event_logger, llm_request_type, response):
197+
def _handle_response_content(span, event_logger, llm_request_type, response):
198+
set_response_content_attributes(span, llm_request_type, response)
89199
if should_emit_events():
90200
emit_response_events(event_logger, llm_request_type, response)
91-
else:
92-
set_response_attributes(span, llm_request_type, response)
93201

94202

95203
@_with_tracer_wrapper
@@ -108,6 +216,55 @@ def _wrap(
108216
):
109217
return wrapped(*args, **kwargs)
110218

219+
name = to_wrap.get("span_name")
220+
llm_request_type = _llm_request_type_by_method(to_wrap.get("method"))
221+
span = tracer.start_span(
222+
name,
223+
kind=SpanKind.CLIENT,
224+
attributes={
225+
SpanAttributes.LLM_SYSTEM: "Cohere",
226+
SpanAttributes.LLM_REQUEST_TYPE: llm_request_type.value,
227+
},
228+
)
229+
230+
with use_span(span, end_on_exit=False):
231+
set_span_request_attributes(span, kwargs)
232+
_handle_input_content(span, event_logger, llm_request_type, kwargs)
233+
234+
try:
235+
response = wrapped(*args, **kwargs)
236+
except Exception as e:
237+
if span.is_recording():
238+
span.set_status(Status(StatusCode.ERROR, str(e)))
239+
span.record_exception(e)
240+
span.end()
241+
raise
242+
243+
if to_wrap.get("stream_process_func"):
244+
return to_wrap.get("stream_process_func")(span, event_logger, llm_request_type, response)
245+
246+
set_span_response_attributes(span, response)
247+
_handle_response_content(span, event_logger, llm_request_type, response)
248+
span.end()
249+
return response
250+
251+
252+
@_with_tracer_wrapper
253+
async def _awrap(
254+
tracer: Tracer,
255+
event_logger: Union[EventLogger, None],
256+
to_wrap,
257+
wrapped,
258+
instance,
259+
args,
260+
kwargs,
261+
):
262+
"""Instruments and calls every function defined in TO_WRAP."""
263+
if context_api.get_value(_SUPPRESS_INSTRUMENTATION_KEY) or context_api.get_value(
264+
SUPPRESS_LANGUAGE_MODEL_INSTRUMENTATION_KEY
265+
):
266+
return await wrapped(*args, **kwargs)
267+
111268
name = to_wrap.get("span_name")
112269
llm_request_type = _llm_request_type_by_method(to_wrap.get("method"))
113270
with tracer.start_as_current_span(
@@ -119,12 +276,19 @@ def _wrap(
119276
},
120277
) as span:
121278
set_span_request_attributes(span, kwargs)
122-
_handle_input(span, event_logger, llm_request_type, kwargs)
279+
_handle_input_content(span, event_logger, llm_request_type, kwargs)
123280

124-
response = wrapped(*args, **kwargs)
281+
try:
282+
response = await wrapped(*args, **kwargs)
283+
except Exception as e:
284+
if span.is_recording():
285+
span.set_status(Status(StatusCode.ERROR, str(e)))
286+
span.record_exception(e)
287+
span.end()
288+
raise
125289

126-
if response:
127-
_handle_response(span, event_logger, llm_request_type, response)
290+
set_span_response_attributes(span, response)
291+
_handle_response_content(span, event_logger, llm_request_type, response)
128292

129293
return response
130294

@@ -151,18 +315,51 @@ def _instrument(self, **kwargs):
151315
__name__, __version__, event_logger_provider=event_logger_provider
152316
)
153317
for wrapped_method in WRAPPED_METHODS:
318+
wrap_module = wrapped_method.get("module")
154319
wrap_object = wrapped_method.get("object")
155320
wrap_method = wrapped_method.get("method")
156-
wrap_function_wrapper(
157-
"cohere.client",
158-
f"{wrap_object}.{wrap_method}",
159-
_wrap(tracer, event_logger, wrapped_method),
160-
)
321+
try:
322+
wrap_function_wrapper(
323+
wrap_module,
324+
f"{wrap_object}.{wrap_method}",
325+
_wrap(tracer, event_logger, wrapped_method),
326+
)
327+
except (ImportError, ModuleNotFoundError, AttributeError):
328+
logger.debug(f"Failed to instrument {wrap_module}.{wrap_object}.{wrap_method}")
329+
330+
for wrapped_method in WRAPPED_AMETHODS:
331+
wrap_module = wrapped_method.get("module")
332+
wrap_object = wrapped_method.get("object")
333+
wrap_method = wrapped_method.get("method")
334+
try:
335+
wrap_function_wrapper(
336+
wrap_module,
337+
f"{wrap_object}.{wrap_method}",
338+
_awrap(tracer, event_logger, wrapped_method),
339+
)
340+
except (ImportError, ModuleNotFoundError, AttributeError):
341+
logger.debug(f"Failed to instrument {wrap_module}.{wrap_object}.{wrap_method}")
161342

162343
def _uninstrument(self, **kwargs):
163344
for wrapped_method in WRAPPED_METHODS:
345+
wrap_module = wrapped_method.get("module")
164346
wrap_object = wrapped_method.get("object")
165-
unwrap(
166-
f"cohere.client.{wrap_object}",
167-
wrapped_method.get("method"),
168-
)
347+
wrap_method = wrapped_method.get("method")
348+
try:
349+
unwrap(
350+
f"{wrap_module}.{wrap_object}",
351+
wrap_method,
352+
)
353+
except (ImportError, ModuleNotFoundError, AttributeError):
354+
logger.debug(f"Failed to uninstrument {wrap_module}.{wrap_object}.{wrap_method}")
355+
for wrapped_method in WRAPPED_AMETHODS:
356+
wrap_module = wrapped_method.get("module")
357+
wrap_object = wrapped_method.get("object")
358+
wrap_method = wrapped_method.get("method")
359+
try:
360+
unwrap(
361+
f"{wrap_module}.{wrap_object}",
362+
wrap_method,
363+
)
364+
except (ImportError, ModuleNotFoundError, AttributeError):
365+
logger.debug(f"Failed to uninstrument {wrap_module}.{wrap_object}.{wrap_method}")

0 commit comments

Comments
 (0)