@@ -22,7 +22,10 @@
guardrail_converse,
guardrail_handling,
)
from opentelemetry.instrumentation.bedrock.prompt_caching import prompt_caching_handling
from opentelemetry.instrumentation.bedrock.prompt_caching import (
prompt_caching_converse_handling,
prompt_caching_handling,
)
from opentelemetry.instrumentation.bedrock.reusable_streaming_body import (
ReusableStreamingBody,
)
@@ -354,6 +357,7 @@ def _handle_call(span: Span, kwargs, response, metric_params, event_logger):
def _handle_converse(span, kwargs, response, metric_params, event_logger):
(provider, model_vendor, model) = _get_vendor_model(kwargs.get("modelId"))
guardrail_converse(span, response, provider, model, metric_params)
prompt_caching_converse_handling(response, provider, model, metric_params)

set_converse_model_span_attributes(span, provider, model, kwargs)

@@ -394,7 +398,11 @@ def wrap(*args, **kwargs):
role = event["messageStart"]["role"]
elif "metadata" in event:
# last message sent
metadata = event.get("metadata", {})
guardrail_converse(span, event["metadata"], provider, model, metric_params)
prompt_caching_converse_handling(
metadata, provider, model, metric_params
)
converse_usage_record(span, event["metadata"], metric_params)
span.end()
elif "messageStop" in event:
@@ -41,3 +41,45 @@ def prompt_caching_handling(headers, vendor, model, metric_params):
)
if write_cached_tokens > 0:
span.set_attribute(CacheSpanAttrs.CACHED, "write")


def prompt_caching_converse_handling(response, vendor, model, metric_params):
base_attrs = {
"gen_ai.system": vendor,
"gen_ai.response.model": model,
}
span = trace.get_current_span()
if not isinstance(span, trace.Span) or not span.is_recording():
return

usage = response.get("usage", {})
read_cached_tokens = usage.get("cache_read_input_tokens", 0)
write_cached_tokens = usage.get("cache_creation_input_tokens", 0)

if read_cached_tokens > 0:
if metric_params.prompt_caching:
metric_params.prompt_caching.add(
read_cached_tokens,
attributes={
**base_attrs,
CacheSpanAttrs.TYPE: "read",
},
)
span.set_attribute(CacheSpanAttrs.CACHED, "read")
span.set_attribute(
"gen_ai.usage.cache_read_input_tokens", read_cached_tokens
)

if write_cached_tokens > 0:
if metric_params.prompt_caching:
metric_params.prompt_caching.add(
write_cached_tokens,
attributes={
**base_attrs,
CacheSpanAttrs.TYPE: "write",
},
)
span.set_attribute(CacheSpanAttrs.CACHED, "write")
span.set_attribute(
"gen_ai.usage.cache_creation_input_tokens", write_cached_tokens
)
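
For reference, a minimal sketch of how `prompt_caching_converse_handling` can be driven on its own. The `usage` keys mirror exactly what the new helper reads above; the `FakeMetricParams` stub, the vendor/model strings, and the counter name are hypothetical placeholders for the instrumentation's real objects, and the sketch assumes an SDK `TracerProvider`/`MeterProvider` are configured elsewhere (with the no-op defaults the helper returns early because the current span is not recording):

```python
# Sketch only, not part of the PR. FakeMetricParams is a hypothetical stand-in
# that mimics just the attribute the helper accesses (metric_params.prompt_caching).
from dataclasses import dataclass

from opentelemetry import metrics, trace
from opentelemetry.instrumentation.bedrock.prompt_caching import (
    prompt_caching_converse_handling,
)


@dataclass
class FakeMetricParams:
    prompt_caching: metrics.Counter  # counter the helper adds cached-token counts to


meter = metrics.get_meter("bedrock-demo")
params = FakeMetricParams(prompt_caching=meter.create_counter("cached_tokens_demo"))

# Shape of the converse response slice the helper inspects.
response = {
    "usage": {
        "cache_read_input_tokens": 0,
        "cache_creation_input_tokens": 128,
    }
}

tracer = trace.get_tracer("bedrock-demo")
with tracer.start_as_current_span("bedrock.converse"):
    # Records a "write" metric data point and sets the cache span attributes.
    prompt_caching_converse_handling(response, "AWS", "claude-3-haiku", params)
```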
@@ -0,0 +1,70 @@
import pytest
from opentelemetry.instrumentation.bedrock import PromptCaching
from opentelemetry.instrumentation.bedrock.prompt_caching import CacheSpanAttrs


def call(brt):
return brt.converse(
modelId="anthropic.claude-3-haiku-20240307-v1:0",
messages=[
{
"role": "user",
"content": [
{
"text": "What is the capital of the USA?",
}
],
}
],
inferenceConfig={"maxTokens": 50, "temperature": 0.1},
additionalModelRequestFields={"cacheControl": {"type": "ephemeral"}},
)


def get_metric(resource_metrics, name):
for rm in resource_metrics:
for sm in rm.scope_metrics:
for metric in sm.metrics:
if metric.name == name:
return metric
raise Exception(f"No metric found with name {name}")


def assert_metric(reader, usage):
metrics_data = reader.get_metrics_data()
resource_metrics = metrics_data.resource_metrics
assert len(resource_metrics) > 0

m = get_metric(resource_metrics, PromptCaching.LLM_BEDROCK_PROMPT_CACHING)
for data_point in m.data.data_points:
assert data_point.attributes[CacheSpanAttrs.TYPE] in [
"read",
"write",
]
if data_point.attributes[CacheSpanAttrs.TYPE] == "read":
assert data_point.value == usage["cache_read_input_tokens"]
else:
assert data_point.value == usage["cache_creation_input_tokens"]


@pytest.mark.vcr
def test_prompt_cache_converse(test_context, brt):
_, _, reader = test_context

response = call(brt)
# assert first prompt writes a cache
usage = response["usage"]
assert usage["cache_read_input_tokens"] == 0
assert usage["cache_creation_input_tokens"] > 0
cumulative_workaround = usage["cache_creation_input_tokens"]
assert_metric(reader, usage)

response = call(brt)
# assert second prompt reads from the cache
usage = response["usage"]
assert usage["cache_read_input_tokens"] > 0
assert usage["cache_creation_input_tokens"] == 0
# data is stored across reads of metric data due to the cumulative behavior
usage["cache_creation_input_tokens"] = cumulative_workaround
assert_metric(reader, usage)
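
The `cumulative_workaround` line compensates for the reader's default cumulative temporality: each collection reports totals since process start, so the write tokens recorded before the first `assert_metric` call still show up in the second read. A standalone sketch of that behavior, independent of Bedrock and assuming the stock SDK in-memory reader:

```python
# Illustration of cumulative temporality only; unrelated to the Bedrock client.
from opentelemetry.sdk.metrics import MeterProvider
from opentelemetry.sdk.metrics.export import InMemoryMetricReader

reader = InMemoryMetricReader()
meter = MeterProvider(metric_readers=[reader]).get_meter("demo")
counter = meter.create_counter("cached_tokens")

counter.add(100, attributes={"type": "write"})  # first call writes the cache
first = reader.get_metrics_data()               # "write" data point value == 100

counter.add(40, attributes={"type": "read"})    # second call reads from the cache
second = reader.get_metrics_data()
# The "write" data point still reports 100 in the second collection (cumulative),
# which is why the test carries cache_creation_input_tokens forward before asserting.
```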