Skip to content

Parse errors in http binding protocols #493

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 9, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 32 additions & 3 deletions packages/smithy-aws-core/src/smithy_aws_core/aio/protocols.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,44 @@
from typing import Final
from typing import Any, Final

from smithy_core.codecs import Codec
from smithy_core.schemas import APIOperation
from smithy_core.shapes import ShapeID
from smithy_http.aio.interfaces import HTTPErrorIdentifier, HTTPResponse
from smithy_http.aio.protocols import HttpBindingClientProtocol
from smithy_json import JSONCodec

from ..traits import RestJson1Trait


class AWSErrorIdentifier(HTTPErrorIdentifier):
    """Identifies errors from the ``x-amzn-errortype`` HTTP response header."""

    _HEADER_KEY: Final = "x-amzn-errortype"

    def identify(
        self,
        *,
        operation: APIOperation[Any, Any],
        response: HTTPResponse,
    ) -> ShapeID | None:
        """Resolve an error's ShapeID from the error type header, if possible.

        :param operation: The operation being invoked. Its namespace is used to
            qualify header values that carry only a shape name.
        :param response: The HTTP response whose fields are inspected.
        :returns: The ShapeID named by the header, or None if the header is
            missing, carries no values, or its first value is empty.
        """
        if self._HEADER_KEY not in response.fields:
            return None

        # Guard against a field that is present but carries no values, which
        # would otherwise raise IndexError on the lookup below.
        values = response.fields[self._HEADER_KEY].values
        code = values[0] if values else None
        if not code:
            return None

        # Only the portion before the first colon names the shape; anything
        # after it (e.g. a trailing URI) is discarded.
        code = code.split(":")[0]
        if "#" in code:
            return ShapeID(code)
        return ShapeID.from_parts(name=code, namespace=operation.schema.id.namespace)


class RestJsonClientProtocol(HttpBindingClientProtocol):
"""An implementation of the aws.protocols#restJson1 protocol."""

_id: ShapeID = RestJson1Trait.id
_codec: JSONCodec = JSONCodec()
_id: Final = RestJson1Trait.id
_codec: Final = JSONCodec()
_contentType: Final = "application/json"
_error_identifier: Final = AWSErrorIdentifier()

@property
def id(self) -> ShapeID:
Expand All @@ -26,3 +51,7 @@ def payload_codec(self) -> Codec:
@property
def content_type(self) -> str:
    """The media type string stored on the protocol as ``_contentType``."""
    media_type = self._contentType
    return media_type

@property
def error_identifier(self) -> HTTPErrorIdentifier:
    """The error identifier stored on the protocol as ``_error_identifier``."""
    identifier = self._error_identifier
    return identifier
3 changes: 2 additions & 1 deletion packages/smithy-aws-core/src/smithy_aws_core/traits.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@
from collections.abc import Mapping, Sequence
from dataclasses import dataclass, field

from smithy_core.documents import DocumentValue
from smithy_core.shapes import ShapeID
from smithy_core.traits import DocumentValue, DynamicTrait, Trait
from smithy_core.traits import DynamicTrait, Trait


@dataclass(init=False, frozen=True)
Expand Down
Empty file.
44 changes: 44 additions & 0 deletions packages/smithy-aws-core/tests/unit/aio/test_protocols.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

from unittest.mock import Mock

import pytest
from smithy_aws_core.aio.protocols import AWSErrorIdentifier
from smithy_core.schemas import APIOperation, Schema
from smithy_core.shapes import ShapeID, ShapeType
from smithy_http import Fields, tuples_to_fields
from smithy_http.aio import HTTPResponse


@pytest.mark.parametrize(
    "header, expected",
    [
        ("FooError", ShapeID("com.test#FooError")),
        (
            "FooError:http://internal.amazon.com/coral/com.amazon.coral.validate/",
            ShapeID("com.test#FooError"),
        ),
        (
            "com.test#FooError:http://internal.amazon.com/coral/com.amazon.coral.validate",
            ShapeID("com.test#FooError"),
        ),
        ("", None),
        (None, None),
    ],
)
def test_aws_error_identifier(header: str | None, expected: ShapeID | None) -> None:
    """Check the ShapeID resolved from various ``x-amzn-errortype`` header values.

    A ``header`` of None means the header is absent from the response entirely.
    Expected values are real ShapeID instances so they match the declared
    parameter type and the assertion compares like with like.
    """
    fields = Fields()
    if header is not None:
        fields = tuples_to_fields([("x-amzn-errortype", header)])
    http_response = HTTPResponse(status=500, fields=fields)

    # Only the operation's schema ID (for its namespace) is needed here.
    operation = Mock(spec=APIOperation)
    operation.schema = Schema(
        id=ShapeID("com.test#TestOperation"), shape_type=ShapeType.OPERATION
    )

    error_identifier = AWSErrorIdentifier()
    actual = error_identifier.identify(operation=operation, response=http_response)

    assert actual == expected
8 changes: 8 additions & 0 deletions packages/smithy-core/src/smithy_core/interfaces/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,14 @@ def is_bytes_reader(obj: Any) -> TypeGuard[BytesReader]:
)


@runtime_checkable
class SeekableBytesReader(BytesReader, Protocol):
    """A synchronous bytes reader that additionally offers ``tell`` and ``seek``."""

    def tell(self) -> int:
        """Return an int; presumably the current position, as with ``io`` streams."""
        ...

    def seek(self, offset: int, whence: int = 0, /) -> int:
        """Return an int; presumably repositions the reader, as with ``io`` streams."""
        ...


# A union of all acceptable streaming blob types. Deserialized payloads will
# always return a ByteStream, or AsyncByteStream if async is enabled.
type StreamingBlob = BytesReader | bytes | bytearray
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from typing import Protocol
from typing import Any, Protocol

from smithy_core.aio.interfaces import ClientTransport, Request, Response
from smithy_core.aio.utils import read_streaming_blob, read_streaming_blob_async
from smithy_core.schemas import APIOperation
from smithy_core.shapes import ShapeID

from ...interfaces import (
Fields,
Expand Down Expand Up @@ -83,3 +85,19 @@ async def send(
:param request_config: Configuration specific to this request.
"""
...


class HTTPErrorIdentifier:
    """Identifies errors using the metadata of an HTTP response.

    Implementations SHOULD NOT touch the body of the response. The payload codec
    is used instead to check for an ID in the body.
    """

    def identify(
        self,
        *,
        operation: APIOperation[Any, Any],
        response: HTTPResponse,
    ) -> ShapeID | None:
        """Identify the ShapeID of an error from an HTTP response.

        This base implementation performs no identification and returns None.

        :param operation: The operation whose response is being inspected.
        :param response: The HTTP response to inspect.
        :returns: The ShapeID of the error, or None if none was identified.
        """
        return None
107 changes: 97 additions & 10 deletions packages/smithy-http/src/smithy_http/aio/protocols.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,29 @@
import os
from inspect import iscoroutinefunction
from io import BytesIO
from typing import Any

from smithy_core.aio.interfaces import ClientProtocol
from smithy_core.codecs import Codec
from smithy_core.deserializers import DeserializeableShape
from smithy_core.documents import TypeRegistry
from smithy_core.exceptions import ExpectationNotMetError
from smithy_core.interfaces import Endpoint, TypedProperties, URI
from smithy_core.exceptions import CallError, ExpectationNotMetError, ModeledError
from smithy_core.interfaces import (
Endpoint,
SeekableBytesReader,
TypedProperties,
URI,
is_streaming_blob,
)
from smithy_core.interfaces import StreamingBlob as SyncStreamingBlob
from smithy_core.prelude import DOCUMENT
from smithy_core.schemas import APIOperation
from smithy_core.serializers import SerializeableShape
from smithy_core.traits import EndpointTrait, HTTPTrait

from smithy_http.aio.interfaces import HTTPRequest, HTTPResponse
from smithy_http.deserializers import HTTPResponseDeserializer
from smithy_http.serializers import HTTPRequestSerializer
from ..deserializers import HTTPResponseDeserializer
from ..serializers import HTTPRequestSerializer
from .interfaces import HTTPErrorIdentifier, HTTPRequest, HTTPResponse


class HttpClientProtocol(ClientProtocol[HTTPRequest, HTTPResponse]):
Expand Down Expand Up @@ -54,6 +63,12 @@ def content_type(self) -> str:
"""The media type of the http payload."""
raise NotImplementedError()

@property
def error_identifier(self) -> HTTPErrorIdentifier:
    """The object used to identify the shape IDs of errors from fields or other
    response information.

    :raises NotImplementedError: Always; no implementation is provided here.
    """
    raise NotImplementedError()

def serialize_request[
OperationInput: "SerializeableShape",
OperationOutput: "DeserializeableShape",
Expand Down Expand Up @@ -94,19 +109,25 @@ async def deserialize_response[
error_registry: TypeRegistry,
context: TypedProperties,
) -> OperationOutput:
if not (200 <= response.status <= 299):
# TODO: implement error serde from type registry
raise NotImplementedError

body = response.body

# if body is not streaming and is async, we have to buffer it
if not operation.output_stream_member:
if not operation.output_stream_member and not is_streaming_blob(body):
if (
read := getattr(body, "read", None)
) is not None and iscoroutinefunction(read):
body = BytesIO(await read())

if not self._is_success(operation, context, response):
raise await self._create_error(
operation=operation,
request=request,
response=response,
response_body=body, # type: ignore
error_registry=error_registry,
context=context,
)

# TODO(optimization): response binding cache like done in SJ
deserializer = HTTPResponseDeserializer(
payload_codec=self.payload_codec,
Expand All @@ -116,3 +137,69 @@ async def deserialize_response[
)

return operation.output.deserialize(deserializer)

def _is_success(
    self,
    operation: APIOperation[Any, Any],
    context: TypedProperties,
    response: HTTPResponse,
) -> bool:
    """Report whether the response status lies in the inclusive range 200-299.

    :param operation: The operation being invoked (not consulted here).
    :param context: The request context (not consulted here).
    :param response: The HTTP response whose status is checked.
    :returns: True for a 2xx status, otherwise False.
    """
    status = response.status
    return status >= 200 and status <= 299

async def _create_error(
    self,
    operation: APIOperation[Any, Any],
    request: HTTPRequest,
    response: HTTPResponse,
    response_body: SyncStreamingBlob,
    error_registry: TypeRegistry,
    context: TypedProperties,
) -> CallError:
    """Build the error to raise for an unsuccessful HTTP response.

    The error ID is first looked up via ``self.error_identifier``. If that
    yields nothing, the body is read as a document with the payload codec and
    its discriminator is used when it is present in the registry.

    :param operation: The operation that was invoked.
    :param request: The HTTP request that was sent (not consulted here).
    :param response: The HTTP response that was received.
    :param response_body: The response body in a synchronously readable form.
    :param error_registry: The registry used to resolve error IDs to shapes.
    :param context: The request context (not consulted here).
    :returns: The deserialized modeled error if the ID is in the registry,
        otherwise a generic ``CallError`` describing the response.
    :raises ExpectationNotMetError: If the registered error shape is not a
        subclass of ``ModeledError``.
    """
    error_id = self.error_identifier.identify(
        operation=operation, response=response
    )

    if error_id is None:
        if isinstance(response_body, bytearray):
            response_body = bytes(response_body)
        deserializer = self.payload_codec.create_deserializer(source=response_body)
        document = deserializer.read_document(schema=DOCUMENT)

        if document.discriminator in error_registry:
            error_id = document.discriminator
            # The body was just consumed to find the discriminator; rewind it
            # so the error shape can be deserialized from the start.
            if isinstance(response_body, SeekableBytesReader):
                response_body.seek(0)

    if error_id is not None and error_id in error_registry:
        error_shape = error_registry.get(error_id)

        # make sure the error shape is derived from modeled exception
        if not issubclass(error_shape, ModeledError):
            raise ExpectationNotMetError(
                f"Modeled errors must be derived from 'ModeledError', "
                f"but got {error_shape}"
            )

        deserializer = HTTPResponseDeserializer(
            payload_codec=self.payload_codec,
            http_trait=operation.schema.expect_trait(HTTPTrait),
            response=response,
            body=response_body,
        )
        return error_shape.deserialize(deserializer)

    # No modeled error could be resolved; fall back to a generic error.
    is_throttle = response.status == 429
    message = (
        f"Unknown error for operation {operation.schema.id} "
        f"- status: {response.status}"
    )
    if error_id is not None:
        message += f" - id: {error_id}"
    if response.reason is not None:
        # Report the reason phrase itself rather than repeating the status.
        message += f" - reason: {response.reason}"
    return CallError(
        message=message,
        fault="client" if response.status < 500 else "server",
        is_throttling_error=is_throttle,
        is_retry_safe=is_throttle or None,
    )
5 changes: 3 additions & 2 deletions packages/smithy-http/src/smithy_http/deserializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,17 @@ class HTTPResponseDeserializer(SpecificShapeDeserializer):
# Note: caller will have to read the body if it's async and not streaming
def __init__(
self,
*,
payload_codec: Codec,
http_trait: HTTPTrait,
response: HTTPResponse,
http_trait: HTTPTrait | None = None,
body: "SyncStreamingBlob | None" = None,
) -> None:
"""Initialize an HTTPResponseDeserializer.

:param payload_codec: The Codec to use to deserialize the payload, if present.
:param http_trait: The HTTP trait of the operation being handled.
:param response: The HTTP response to read from.
:param http_trait: The HTTP trait of the operation being handled.
:param body: The HTTP response body in a synchronously readable form. This is
necessary for async response bodies when there is no streaming member.
"""
Expand Down