diff --git a/packages/opentelemetry-instrumentation-bedrock/opentelemetry/instrumentation/bedrock/__init__.py b/packages/opentelemetry-instrumentation-bedrock/opentelemetry/instrumentation/bedrock/__init__.py
index 9b1ec775bd..ba97c54a9e 100644
--- a/packages/opentelemetry-instrumentation-bedrock/opentelemetry/instrumentation/bedrock/__init__.py
+++ b/packages/opentelemetry-instrumentation-bedrock/opentelemetry/instrumentation/bedrock/__init__.py
@@ -22,7 +22,10 @@
     guardrail_converse,
     guardrail_handling,
 )
-from opentelemetry.instrumentation.bedrock.prompt_caching import prompt_caching_handling
+from opentelemetry.instrumentation.bedrock.prompt_caching import (
+    prompt_caching_converse_handling,
+    prompt_caching_handling,
+)
 from opentelemetry.instrumentation.bedrock.reusable_streaming_body import (
     ReusableStreamingBody,
 )
@@ -354,6 +357,7 @@ def _handle_call(span: Span, kwargs, response, metric_params, event_logger):
 def _handle_converse(span, kwargs, response, metric_params, event_logger):
     (provider, model_vendor, model) = _get_vendor_model(kwargs.get("modelId"))
     guardrail_converse(span, response, provider, model, metric_params)
+    prompt_caching_converse_handling(response, provider, model, metric_params)
     set_converse_model_span_attributes(span, provider, model, kwargs)
 
 
@@ -394,7 +398,11 @@ def wrap(*args, **kwargs):
                     role = event["messageStart"]["role"]
                 elif "metadata" in event:
                     # last message sent
+                    metadata = event.get("metadata", {})
                     guardrail_converse(span, event["metadata"], provider, model, metric_params)
+                    prompt_caching_converse_handling(
+                        metadata, provider, model, metric_params
+                    )
                     converse_usage_record(span, event["metadata"], metric_params)
                     span.end()
                 elif "messageStop" in event:
diff --git a/packages/opentelemetry-instrumentation-bedrock/opentelemetry/instrumentation/bedrock/prompt_caching.py b/packages/opentelemetry-instrumentation-bedrock/opentelemetry/instrumentation/bedrock/prompt_caching.py
index b94dc66127..f98d57978b 100644
--- a/packages/opentelemetry-instrumentation-bedrock/opentelemetry/instrumentation/bedrock/prompt_caching.py
+++ b/packages/opentelemetry-instrumentation-bedrock/opentelemetry/instrumentation/bedrock/prompt_caching.py
@@ -41,3 +41,45 @@ def prompt_caching_handling(headers, vendor, model, metric_params):
             )
         if write_cached_tokens > 0:
             span.set_attribute(CacheSpanAttrs.CACHED, "write")
+
+
+def prompt_caching_converse_handling(response, vendor, model, metric_params):
+    base_attrs = {
+        "gen_ai.system": vendor,
+        "gen_ai.response.model": model,
+    }
+    span = trace.get_current_span()
+    if not isinstance(span, trace.Span) or not span.is_recording():
+        return
+
+    usage = response.get("usage", {})
+    read_cached_tokens = usage.get("cache_read_input_tokens", 0)
+    write_cached_tokens = usage.get("cache_creation_input_tokens", 0)
+
+    if read_cached_tokens > 0:
+        if metric_params.prompt_caching:
+            metric_params.prompt_caching.add(
+                read_cached_tokens,
+                attributes={
+                    **base_attrs,
+                    CacheSpanAttrs.TYPE: "read",
+                },
+            )
+        span.set_attribute(CacheSpanAttrs.CACHED, "read")
+        span.set_attribute(
+            "gen_ai.usage.cache_read_input_tokens", read_cached_tokens
+        )
+
+    if write_cached_tokens > 0:
+        if metric_params.prompt_caching:
+            metric_params.prompt_caching.add(
+                write_cached_tokens,
+                attributes={
+                    **base_attrs,
+                    CacheSpanAttrs.TYPE: "write",
+                },
+            )
+        span.set_attribute(CacheSpanAttrs.CACHED, "write")
+        span.set_attribute(
+            "gen_ai.usage.cache_creation_input_tokens", write_cached_tokens
+        )
diff --git a/packages/opentelemetry-instrumentation-bedrock/tests/metrics/test_bedrock_converse_prompt_caching_metrics.py b/packages/opentelemetry-instrumentation-bedrock/tests/metrics/test_bedrock_converse_prompt_caching_metrics.py
new file mode 100644
index 0000000000..81cb370c35
--- /dev/null
+++ b/packages/opentelemetry-instrumentation-bedrock/tests/metrics/test_bedrock_converse_prompt_caching_metrics.py
@@ -0,0 +1,70 @@
+import pytest
+from opentelemetry.instrumentation.bedrock import PromptCaching
+from opentelemetry.instrumentation.bedrock.prompt_caching import CacheSpanAttrs
+
+
+def call(brt):
+    return brt.converse(
+        modelId="anthropic.claude-3-haiku-20240307-v1:0",
+        messages=[
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "text": "What is the capital of the USA?",
+                    }
+                ],
+            }
+        ],
+        inferenceConfig={"maxTokens": 50, "temperature": 0.1},
+        additionalModelRequestFields={"cacheControl": {"type": "ephemeral"}},
+    )
+
+
+def get_metric(resource_metrics, name):
+    for rm in resource_metrics:
+        for sm in rm.scope_metrics:
+            for metric in sm.metrics:
+                if metric.name == name:
+                    return metric
+    raise Exception(f"No metric found with name {name}")
+
+
+def assert_metric(reader, usage):
+    metrics_data = reader.get_metrics_data()
+    resource_metrics = metrics_data.resource_metrics
+    assert len(resource_metrics) > 0
+
+    m = get_metric(resource_metrics, PromptCaching.LLM_BEDROCK_PROMPT_CACHING)
+    for data_point in m.data.data_points:
+        assert data_point.attributes[CacheSpanAttrs.TYPE] in [
+            "read",
+            "write",
+        ]
+        if data_point.attributes[CacheSpanAttrs.TYPE] == "read":
+            assert data_point.value == usage["cache_read_input_tokens"]
+        else:
+            assert data_point.value == usage["cache_creation_input_tokens"]
+
+
+@pytest.mark.vcr
+def test_prompt_cache_converse(test_context, brt):
+    _, _, reader = test_context
+
+    response = call(brt)
+    # assert first prompt writes a cache
+    usage = response["usage"]
+    assert usage["cache_read_input_tokens"] == 0
+    assert usage["cache_creation_input_tokens"] > 0
+    cumulative_workaround = usage["cache_creation_input_tokens"]
+    assert_metric(reader, usage)
+
+    response = call(brt)
+    # assert second prompt reads from the cache
+    usage = response["usage"]
+    assert usage["cache_read_input_tokens"] > 0
+    assert usage["cache_creation_input_tokens"] == 0
+    # data is stored across reads of metric data due to the cumulative behavior
+    usage["cache_creation_input_tokens"] = cumulative_workaround
+    assert_metric(reader, usage)
+
\ No newline at end of file
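
For reviewers, a minimal end-to-end sketch of how the new converse cache accounting is exercised. This is not part of the patch: the SDK wiring below is illustrative (the real tests rely on the test_context fixture instead), and brt is assumed to be a boto3 bedrock-runtime client with valid credentials.

import boto3
from opentelemetry import metrics, trace
from opentelemetry.instrumentation.bedrock import BedrockInstrumentor
from opentelemetry.sdk.metrics import MeterProvider
from opentelemetry.sdk.metrics.export import InMemoryMetricReader
from opentelemetry.sdk.trace import TracerProvider

# Illustrative wiring only; instrument before the boto3 client is created so the
# botocore client gets wrapped.
reader = InMemoryMetricReader()
metrics.set_meter_provider(MeterProvider(metric_readers=[reader]))
trace.set_tracer_provider(TracerProvider())
BedrockInstrumentor().instrument()

brt = boto3.client("bedrock-runtime")
response = brt.converse(
    modelId="anthropic.claude-3-haiku-20240307-v1:0",
    messages=[{"role": "user", "content": [{"text": "What is the capital of the USA?"}]}],
    inferenceConfig={"maxTokens": 50, "temperature": 0.1},
    additionalModelRequestFields={"cacheControl": {"type": "ephemeral"}},
)

# The first call is expected to write the cache; an identical follow-up call should
# read from it. The same numbers are recorded on the
# PromptCaching.LLM_BEDROCK_PROMPT_CACHING counter (tagged CacheSpanAttrs.TYPE =
# "read"/"write") and on the converse span as
# gen_ai.usage.cache_read_input_tokens / gen_ai.usage.cache_creation_input_tokens.
usage = response["usage"]
print(usage.get("cache_creation_input_tokens", 0), usage.get("cache_read_input_tokens", 0))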