Skip to content

Commit 9ee1229

Browse files
Handle AWS error codes in documents
1 parent 5e2e774 commit 9ee1229

File tree

4 files changed

+186
-10
lines changed

4 files changed

+186
-10
lines changed

packages/smithy-aws-core/src/smithy_aws_core/aio/protocols.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
from typing import Any, Final
22

33
from smithy_core.codecs import Codec
4+
from smithy_core.exceptions import DiscriminatorError
45
from smithy_core.schemas import APIOperation
5-
from smithy_core.shapes import ShapeID
6+
from smithy_core.shapes import ShapeID, ShapeType
67
from smithy_http.aio.interfaces import HTTPErrorIdentifier, HTTPResponse
78
from smithy_http.aio.protocols import HttpBindingClientProtocol
8-
from smithy_json import JSONCodec
9+
from smithy_json import JSONCodec, JSONDocument
910

1011
from ..traits import RestJson1Trait
12+
from ..utils import parse_document_discriminator, parse_error_code
1113

1214

1315
class AWSErrorIdentifier(HTTPErrorIdentifier):
@@ -24,20 +26,29 @@ def identify(
2426

2527
error_field = response.fields[self._HEADER_KEY]
2628
code = error_field.values[0] if len(error_field.values) > 0 else None
27-
if not code:
28-
return None
29+
if code is not None:
30+
return parse_error_code(code, operation.schema.id.namespace)
31+
return None
32+
2933

30-
code = code.split(":")[0]
31-
if "#" in code:
32-
return ShapeID(code)
33-
return ShapeID.from_parts(name=code, namespace=operation.schema.id.namespace)
34+
class AWSJSONDocument(JSONDocument):
35+
@property
36+
def discriminator(self) -> ShapeID:
37+
if self.shape_type is ShapeType.STRUCTURE:
38+
return self._schema.id
39+
parsed = parse_document_discriminator(self, self._settings.default_namespace)
40+
if parsed is None:
41+
raise DiscriminatorError(
42+
f"Unable to parse discriminator for {self.shape_type} docuemnt."
43+
)
44+
return parsed
3445

3546

3647
class RestJsonClientProtocol(HttpBindingClientProtocol):
3748
"""An implementation of the aws.protocols#restJson1 protocol."""
3849

3950
_id: Final = RestJson1Trait.id
40-
_codec: Final = JSONCodec()
51+
_codec: Final = JSONCodec(document_class=AWSJSONDocument)
4152
_contentType: Final = "application/json"
4253
_error_identifier: Final = AWSErrorIdentifier()
4354

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
from smithy_core.documents import Document
4+
from smithy_core.shapes import ShapeID, ShapeType
5+
6+
7+
def parse_document_discriminator(
8+
document: Document, default_namespace: str | None
9+
) -> ShapeID | None:
10+
if document.shape_type is ShapeType.MAP:
11+
map_document = document.as_map()
12+
code = map_document.get("__type")
13+
if code is None:
14+
code = map_document.get("code")
15+
if code is not None and code.shape_type is ShapeType.STRING:
16+
return parse_error_code(code.as_string(), default_namespace)
17+
18+
return None
19+
20+
21+
def parse_error_code(code: str, default_namespace: str | None) -> ShapeID | None:
22+
if not code:
23+
return None
24+
25+
code = code.split(":")[0]
26+
if "#" in code:
27+
return ShapeID(code)
28+
29+
if not code or not default_namespace:
30+
return None
31+
32+
return ShapeID.from_parts(name=code, namespace=default_namespace)

packages/smithy-aws-core/tests/unit/aio/test_protocols.py

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,13 @@
44
from unittest.mock import Mock
55

66
import pytest
7-
from smithy_aws_core.aio.protocols import AWSErrorIdentifier
7+
from smithy_aws_core.aio.protocols import AWSErrorIdentifier, AWSJSONDocument
8+
from smithy_core.exceptions import DiscriminatorError
89
from smithy_core.schemas import APIOperation, Schema
910
from smithy_core.shapes import ShapeID, ShapeType
1011
from smithy_http import Fields, tuples_to_fields
1112
from smithy_http.aio import HTTPResponse
13+
from smithy_json import JSONSettings
1214

1315

1416
@pytest.mark.parametrize(
@@ -24,6 +26,7 @@
2426
"com.test#FooError",
2527
),
2628
("", None),
29+
(":", None),
2730
(None, None),
2831
],
2932
)
@@ -42,3 +45,55 @@ def test_aws_error_identifier(header: str | None, expected: ShapeID | None) -> N
4245
actual = error_identifier.identify(operation=operation, response=http_response)
4346

4447
assert actual == expected
48+
49+
50+
@pytest.mark.parametrize(
51+
"document, expected",
52+
[
53+
({"__type": "FooError"}, "com.test#FooError"),
54+
({"__type": "com.test#FooError"}, "com.test#FooError"),
55+
(
56+
{
57+
"__type": "FooError:http://internal.amazon.com/coral/com.amazon.coral.validate/"
58+
},
59+
"com.test#FooError",
60+
),
61+
(
62+
{
63+
"__type": "com.test#FooError:http://internal.amazon.com/coral/com.amazon.coral.validate"
64+
},
65+
"com.test#FooError",
66+
),
67+
({"code": "FooError"}, "com.test#FooError"),
68+
({"code": "com.test#FooError"}, "com.test#FooError"),
69+
(
70+
{
71+
"code": "FooError:http://internal.amazon.com/coral/com.amazon.coral.validate/"
72+
},
73+
"com.test#FooError",
74+
),
75+
(
76+
{
77+
"code": "com.test#FooError:http://internal.amazon.com/coral/com.amazon.coral.validate"
78+
},
79+
"com.test#FooError",
80+
),
81+
({"__type": "FooError", "code": "BarError"}, "com.test#FooError"),
82+
("FooError", None),
83+
({"__type": None}, None),
84+
({"__type": ""}, None),
85+
({"__type": ":"}, None),
86+
],
87+
)
88+
def test_aws_json_document_discriminator(
89+
document: dict[str, str], expected: ShapeID | None
90+
) -> None:
91+
settings = JSONSettings(
92+
document_class=AWSJSONDocument, default_namespace="com.test"
93+
)
94+
if expected is None:
95+
with pytest.raises(DiscriminatorError):
96+
AWSJSONDocument(document, settings=settings).discriminator
97+
else:
98+
discriminator = AWSJSONDocument(document, settings=settings).discriminator
99+
assert discriminator == expected
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
import pytest
5+
from smithy_aws_core.utils import parse_document_discriminator, parse_error_code
6+
from smithy_core.documents import Document
7+
from smithy_core.shapes import ShapeID
8+
9+
10+
@pytest.mark.parametrize(
11+
"document, expected",
12+
[
13+
({"__type": "FooError"}, "com.test#FooError"),
14+
({"__type": "com.test#FooError"}, "com.test#FooError"),
15+
(
16+
{
17+
"__type": "FooError:http://internal.amazon.com/coral/com.amazon.coral.validate/"
18+
},
19+
"com.test#FooError",
20+
),
21+
(
22+
{
23+
"__type": "com.test#FooError:http://internal.amazon.com/coral/com.amazon.coral.validate"
24+
},
25+
"com.test#FooError",
26+
),
27+
({"code": "FooError"}, "com.test#FooError"),
28+
({"code": "com.test#FooError"}, "com.test#FooError"),
29+
(
30+
{
31+
"code": "FooError:http://internal.amazon.com/coral/com.amazon.coral.validate/"
32+
},
33+
"com.test#FooError",
34+
),
35+
(
36+
{
37+
"code": "com.test#FooError:http://internal.amazon.com/coral/com.amazon.coral.validate"
38+
},
39+
"com.test#FooError",
40+
),
41+
({"__type": "FooError", "code": "BarError"}, "com.test#FooError"),
42+
("FooError", None),
43+
({"__type": None}, None),
44+
({"__type": ""}, None),
45+
({"__type": ":"}, None),
46+
],
47+
)
48+
def test_aws_json_document_discriminator(
49+
document: dict[str, str], expected: ShapeID | None
50+
) -> None:
51+
actual = parse_document_discriminator(Document(document), "com.test")
52+
assert actual == expected
53+
54+
55+
@pytest.mark.parametrize(
56+
"code, expected",
57+
[
58+
("FooError", "com.test#FooError"),
59+
(
60+
"FooError:http://internal.amazon.com/coral/com.amazon.coral.validate/",
61+
"com.test#FooError",
62+
),
63+
(
64+
"com.test#FooError:http://internal.amazon.com/coral/com.amazon.coral.validate",
65+
"com.test#FooError",
66+
),
67+
("", None),
68+
(":", None),
69+
],
70+
)
71+
def test_parse_error_code(code: str, expected: ShapeID | None) -> None:
72+
actual = parse_error_code(code, "com.test")
73+
assert actual == expected
74+
75+
76+
def test_parse_error_code_without_default_namespace() -> None:
77+
actual = parse_error_code("FooError", None)
78+
assert actual is None

0 commit comments

Comments
 (0)