From b0690f33ad34cc184a047047819a597fcd86c5f8 Mon Sep 17 00:00:00 2001 From: JordonPhillips Date: Wed, 7 May 2025 16:57:34 +0200 Subject: [PATCH 1/5] Only have default discriminator for struct docs --- packages/smithy-core/src/smithy_core/documents.py | 6 ++++-- packages/smithy-core/src/smithy_core/exceptions.py | 5 +++++ packages/smithy-core/tests/unit/test_documents.py | 12 +++++++++++- .../smithy-core/tests/unit/test_type_registry.py | 13 ++++++++++--- 4 files changed, 30 insertions(+), 6 deletions(-) diff --git a/packages/smithy-core/src/smithy_core/documents.py b/packages/smithy-core/src/smithy_core/documents.py index 2166bae1..1c565fa9 100644 --- a/packages/smithy-core/src/smithy_core/documents.py +++ b/packages/smithy-core/src/smithy_core/documents.py @@ -5,7 +5,7 @@ from typing import TypeGuard, override from .deserializers import DeserializeableShape, ShapeDeserializer -from .exceptions import ExpectationNotMetError, SmithyError +from .exceptions import DiscriminatorError, ExpectationNotMetError, SmithyError from .schemas import Schema from .serializers import ( InterceptingSerializer, @@ -146,7 +146,9 @@ def shape_type(self) -> ShapeType: @property def discriminator(self) -> ShapeID: """The shape ID that corresponds to the contents of the document.""" - return self._schema.id + if self._type is ShapeType.STRUCTURE: + return self._schema.id + raise DiscriminatorError(f"{self._type} document has no discriminator.") def is_none(self) -> bool: """Indicates whether the document contains a null value.""" diff --git a/packages/smithy-core/src/smithy_core/exceptions.py b/packages/smithy-core/src/smithy_core/exceptions.py index 7d320b76..0e28bd53 100644 --- a/packages/smithy-core/src/smithy_core/exceptions.py +++ b/packages/smithy-core/src/smithy_core/exceptions.py @@ -65,6 +65,11 @@ class SerializationError(SmithyError): """Base exception type for exceptions raised during serialization.""" +class DiscriminatorError(SmithyError): + """Exception indicating something went wrong when attempting to find the + discriminator in a document.""" + + class RetryError(SmithyError): """Base exception type for all exceptions raised in retry strategies.""" diff --git a/packages/smithy-core/tests/unit/test_documents.py b/packages/smithy-core/tests/unit/test_documents.py index 1ae24eb9..8a1e3ccf 100644 --- a/packages/smithy-core/tests/unit/test_documents.py +++ b/packages/smithy-core/tests/unit/test_documents.py @@ -12,7 +12,7 @@ _DocumentDeserializer, _DocumentSerializer, ) -from smithy_core.exceptions import ExpectationNotMetError +from smithy_core.exceptions import DiscriminatorError, ExpectationNotMetError from smithy_core.prelude import ( BIG_DECIMAL, BLOB, @@ -938,3 +938,13 @@ def _read_optional_map(k: str, d: ShapeDeserializer): actual = given.as_shape(DocumentSerdeShape) case _: raise Exception(f"Unexpected type: {type(given)}") + + +def test_document_has_no_discriminator_by_default() -> None: + with pytest.raises(DiscriminatorError): + Document().discriminator + + +def test_struct_document_has_discriminator() -> None: + document = Document({"integerMember": 1}, schema=SCHEMA) + assert document.discriminator == SCHEMA.id diff --git a/packages/smithy-core/tests/unit/test_type_registry.py b/packages/smithy-core/tests/unit/test_type_registry.py index 3c7ca98c..6a2a4ffc 100644 --- a/packages/smithy-core/tests/unit/test_type_registry.py +++ b/packages/smithy-core/tests/unit/test_type_registry.py @@ -1,8 +1,10 @@ import pytest from smithy_core.deserializers import DeserializeableShape, ShapeDeserializer from smithy_core.documents import Document, TypeRegistry +from smithy_core.prelude import STRING from smithy_core.schemas import Schema -from smithy_core.shapes import ShapeID, ShapeType +from smithy_core.shapes import ShapeID +from smithy_core.traits import RequiredTrait def test_get(): @@ -59,11 +61,16 @@ def test_deserialize(): class TestShape(DeserializeableShape): __test__ = False - schema = Schema(id=ShapeID("com.example#Test"), shape_type=ShapeType.STRING) + schema = Schema.collection( + id=ShapeID("com.example#Test"), + members={"value": {"index": 0, "target": STRING, "traits": [RequiredTrait()]}}, + ) def __init__(self, value: str): self.value = value @classmethod def deserialize(cls, deserializer: ShapeDeserializer) -> "TestShape": - return TestShape(deserializer.read_string(schema=TestShape.schema)) + return TestShape( + value=deserializer.read_string(schema=cls.schema.members["value"]) + ) From aa7297217c92dcddc5917d30e91a87b80c8ef80e Mon Sep 17 00:00:00 2001 From: JordonPhillips Date: Wed, 7 May 2025 17:26:26 +0200 Subject: [PATCH 2/5] Use shared settings for JSON codec --- .../smithy-json/src/smithy_json/__init__.py | 28 +++++---------- .../src/smithy_json/_private/__init__.py | 10 ++++++ .../src/smithy_json/_private/deserializers.py | 34 +++++-------------- .../src/smithy_json/_private/documents.py | 33 ++++++++---------- .../src/smithy_json/_private/serializers.py | 23 ++++--------- 5 files changed, 48 insertions(+), 80 deletions(-) diff --git a/packages/smithy-json/src/smithy_json/__init__.py b/packages/smithy-json/src/smithy_json/__init__.py index c90653d0..53675435 100644 --- a/packages/smithy-json/src/smithy_json/__init__.py +++ b/packages/smithy-json/src/smithy_json/__init__.py @@ -9,6 +9,7 @@ from smithy_core.serializers import ShapeSerializer from smithy_core.types import TimestampFormat +from ._private import JSONSettings as _JSONSettings from ._private.deserializers import JSONShapeDeserializer as _JSONShapeDeserializer from ._private.serializers import JSONShapeSerializer as _JSONShapeSerializer @@ -18,15 +19,12 @@ class JSONCodec(Codec): """A codec for converting shapes to/from JSON.""" - _use_json_name: bool - _use_timestamp_format: bool - _default_timestamp_format: TimestampFormat - def __init__( self, use_json_name: bool = True, use_timestamp_format: bool = True, default_timestamp_format: TimestampFormat = TimestampFormat.DATE_TIME, + default_namespace: str | None = None, ) -> None: """Initializes a JSONCodec. @@ -37,28 +35,20 @@ def __init__( :param default_timestamp_format: The default timestamp format to use if the `smithy.api#timestampFormat` trait is not enabled or not present. """ - self._use_json_name = use_json_name - self._use_timestamp_format = use_timestamp_format - self._default_timestamp_format = default_timestamp_format + self._settings = _JSONSettings( + use_json_name=use_json_name, + use_timestamp_format=use_timestamp_format, + default_timestamp_format=default_timestamp_format, + ) @property def media_type(self) -> str: return "application/json" def create_serializer(self, sink: BytesWriter) -> "ShapeSerializer": - return _JSONShapeSerializer( - sink, - use_json_name=self._use_json_name, - use_timestamp_format=self._use_timestamp_format, - default_timestamp_format=self._default_timestamp_format, - ) + return _JSONShapeSerializer(sink, settings=self._settings) def create_deserializer(self, source: bytes | BytesReader) -> "ShapeDeserializer": if isinstance(source, bytes): source = BytesIO(source) - return _JSONShapeDeserializer( - source, - use_json_name=self._use_json_name, - use_timestamp_format=self._use_timestamp_format, - default_timestamp_format=self._default_timestamp_format, - ) + return _JSONShapeDeserializer(source, settings=self._settings) diff --git a/packages/smithy-json/src/smithy_json/_private/__init__.py b/packages/smithy-json/src/smithy_json/_private/__init__.py index d5e52e68..faa62e92 100644 --- a/packages/smithy-json/src/smithy_json/_private/__init__.py +++ b/packages/smithy-json/src/smithy_json/_private/__init__.py @@ -1,11 +1,21 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass from typing import Protocol, runtime_checkable +from smithy_core.types import TimestampFormat + @runtime_checkable class Flushable(Protocol): """A protocol for objects that can be flushed.""" def flush(self) -> None: ... + + +@dataclass +class JSONSettings: + use_json_name: bool = True + use_timestamp_format: bool = True + default_timestamp_format: TimestampFormat = TimestampFormat.DATE_TIME diff --git a/packages/smithy-json/src/smithy_json/_private/deserializers.py b/packages/smithy-json/src/smithy_json/_private/deserializers.py index ac79f646..8607c85e 100644 --- a/packages/smithy-json/src/smithy_json/_private/deserializers.py +++ b/packages/smithy-json/src/smithy_json/_private/deserializers.py @@ -18,6 +18,7 @@ from smithy_core.traits import JSONNameTrait, TimestampFormatTrait from smithy_core.types import TimestampFormat +from . import JSONSettings from .documents import JSONDocument # TODO: put these type hints in a pyi somewhere. There here because ijson isn't @@ -89,17 +90,10 @@ def peek(self) -> JSONParseEvent: class JSONShapeDeserializer(ShapeDeserializer): def __init__( - self, - source: BytesReader, - *, - use_json_name: bool = True, - use_timestamp_format: bool = True, - default_timestamp_format: TimestampFormat = TimestampFormat.DATE_TIME, + self, source: BytesReader, *, settings: JSONSettings | None = None ) -> None: self._stream = BufferedParser(ijson.parse(source)) - self._use_json_name = use_json_name - self._use_timestamp_format = use_timestamp_format - self._default_timestamp_format = default_timestamp_format + self._settings = settings or JSONSettings() # A mapping of json name to member name for each shape. Since the deserializer # is shared and we don't know which shapes will be deserialized, this is @@ -164,13 +158,7 @@ def read_string(self, schema: Schema) -> str: def read_document(self, schema: Schema) -> Document: start = next(self._stream) if start.type not in ("start_map", "start_array"): - return JSONDocument( - start.value, - schema=schema, - use_json_name=self._use_json_name, - default_timestamp_format=self._default_timestamp_format, - use_timestamp_format=self._use_timestamp_format, - ) + return JSONDocument(start.value, schema=schema, settings=self._settings) end_type = "end_map" if start.type == "start_map" else "end_array" builder = cast(TypedObjectBuilder, ObjectBuilder()) @@ -180,17 +168,11 @@ def read_document(self, schema: Schema) -> Document: ).path != start.path or event.type != end_type: builder.event(event.type, event.value) - return JSONDocument( - builder.value, - schema=schema, - use_json_name=self._use_json_name, - default_timestamp_format=self._default_timestamp_format, - use_timestamp_format=self._use_timestamp_format, - ) + return JSONDocument(builder.value, schema=schema, settings=self._settings) def read_timestamp(self, schema: Schema) -> datetime.datetime: - format = self._default_timestamp_format - if self._use_timestamp_format: + format = self._settings.default_timestamp_format + if self._settings.use_timestamp_format: if format_trait := schema.get_trait(TimestampFormatTrait): format = format_trait.format @@ -221,7 +203,7 @@ def read_struct( next(self._stream) def _resolve_member(self, schema: Schema, key: str) -> Schema | None: - if self._use_json_name: + if self._settings.use_json_name: if schema.id not in self._json_names: self._cache_json_names(schema=schema) if key in self._json_names[schema.id]: diff --git a/packages/smithy-json/src/smithy_json/_private/documents.py b/packages/smithy-json/src/smithy_json/_private/documents.py index 0ad962a4..0fcdd07e 100644 --- a/packages/smithy-json/src/smithy_json/_private/documents.py +++ b/packages/smithy-json/src/smithy_json/_private/documents.py @@ -9,11 +9,12 @@ from smithy_core.documents import Document, DocumentValue from smithy_core.prelude import DOCUMENT from smithy_core.schemas import Schema -from smithy_core.shapes import ShapeType +from smithy_core.shapes import ShapeID, ShapeType from smithy_core.traits import JSONNameTrait, TimestampFormatTrait -from smithy_core.types import TimestampFormat from smithy_core.utils import expect_type +from . import JSONSettings + class JSONDocument(Document): _schema: Schema @@ -24,17 +25,13 @@ def __init__( value: DocumentValue | dict[str, "Document"] | list["Document"], *, schema: Schema = DOCUMENT, - use_json_name: bool = True, - use_timestamp_format: bool = True, - default_timestamp_format: TimestampFormat = TimestampFormat.DATE_TIME, + settings: JSONSettings | None = None, ) -> None: super().__init__(value, schema=schema) - self._use_json_name = use_json_name - self._use_timestamp_format = use_timestamp_format - self._default_timestamp_format = default_timestamp_format + self._settings = settings or JSONSettings() self._json_names = {} - if use_json_name and schema.shape_type in ( + if self._settings.use_json_name and schema.shape_type in ( ShapeType.STRUCTURE, ShapeType.UNION, ): @@ -42,6 +39,12 @@ def __init__( if json_name := member_schema.get_trait(JSONNameTrait): self._json_names[json_name.value] = member_name + @property + def discriminator(self) -> ShapeID: + if self._type is ShapeType.MAP: + return ShapeID(self.as_map()["__type"].as_string()) + return super().discriminator + def as_blob(self) -> bytes: return b64decode(expect_type(str, self._value)) @@ -51,8 +54,8 @@ def as_float(self) -> float: return float(expect_type(Decimal, self._value)) def as_timestamp(self) -> datetime: - format = self._default_timestamp_format - if self._use_timestamp_format: + format = self._settings.default_timestamp_format + if self._settings.use_timestamp_format: if format_trait := self._schema.get_trait(TimestampFormatTrait): format = format_trait.format @@ -106,13 +109,7 @@ def _new_document( value: DocumentValue | dict[str, "Document"] | list["Document"], schema: Schema, ) -> "Document": - return JSONDocument( - value, - schema=schema, - use_json_name=self._use_json_name, - use_timestamp_format=self._use_timestamp_format, - default_timestamp_format=self._default_timestamp_format, - ) + return JSONDocument(value, schema=schema, settings=self._settings) def _wrap_map(self, value: Mapping[str, DocumentValue]) -> dict[str, "Document"]: if self._schema.shape_type not in (ShapeType.STRUCTURE, ShapeType.UNION): diff --git a/packages/smithy-json/src/smithy_json/_private/serializers.py b/packages/smithy-json/src/smithy_json/_private/serializers.py index 9c56a9da..e3a0051c 100644 --- a/packages/smithy-json/src/smithy_json/_private/serializers.py +++ b/packages/smithy-json/src/smithy_json/_private/serializers.py @@ -21,34 +21,23 @@ from smithy_core.traits import JSONNameTrait, TimestampFormatTrait from smithy_core.types import TimestampFormat -from . import Flushable +from . import Flushable, JSONSettings _INF: float = float("inf") _NEG_INF: float = float("-inf") class JSONShapeSerializer(ShapeSerializer): - _stream: "StreamingJSONEncoder" - _use_json_name: bool - _use_timestamp_format: bool - _default_timestamp_format: TimestampFormat - def __init__( - self, - sink: BytesWriter, - use_json_name: bool = True, - use_timestamp_format: bool = True, - default_timestamp_format: TimestampFormat = TimestampFormat.DATE_TIME, + self, sink: BytesWriter, *, settings: JSONSettings | None = None ) -> None: self._stream = StreamingJSONEncoder(sink) - self._use_json_name = use_json_name - self._use_timestamp_format = use_timestamp_format - self._default_timestamp_format = default_timestamp_format + self._settings = settings or JSONSettings() def begin_struct( self, schema: "Schema" ) -> AbstractContextManager["ShapeSerializer"]: - return JSONStructSerializer(self._stream, self, self._use_json_name) + return JSONStructSerializer(self._stream, self, self._settings.use_json_name) def begin_list( self, schema: "Schema", size: int @@ -82,8 +71,8 @@ def write_blob(self, schema: "Schema", value: bytes) -> None: self._stream.write_string(b64encode(value).decode("utf-8")) def write_timestamp(self, schema: "Schema", value: datetime) -> None: - format = self._default_timestamp_format - if self._use_timestamp_format: + format = self._settings.default_timestamp_format + if self._settings.use_timestamp_format: if format_trait := schema.get_trait(TimestampFormatTrait): format = format_trait.format From 5e2e774d811f4264cdc549918339bff66fdfb965 Mon Sep 17 00:00:00 2001 From: JordonPhillips Date: Fri, 9 May 2025 13:38:01 +0200 Subject: [PATCH 3/5] Pass JSON document class through settings This updates the JSON codec settings to include the document class to use during deserialization. This is primarily intended to allow customization of the discriminator. --- .../smithy-json/src/smithy_json/__init__.py | 12 +++++-- .../src/smithy_json/_private/__init__.py | 11 ------- .../src/smithy_json/_private/deserializers.py | 17 +++++----- .../src/smithy_json/_private/documents.py | 17 +++------- .../src/smithy_json/_private/serializers.py | 9 +++--- .../smithy-json/src/smithy_json/settings.py | 31 +++++++++++++++++++ .../tests/unit/test_deserializers.py | 13 +++++++- 7 files changed, 71 insertions(+), 39 deletions(-) create mode 100644 packages/smithy-json/src/smithy_json/settings.py diff --git a/packages/smithy-json/src/smithy_json/__init__.py b/packages/smithy-json/src/smithy_json/__init__.py index 53675435..2993deed 100644 --- a/packages/smithy-json/src/smithy_json/__init__.py +++ b/packages/smithy-json/src/smithy_json/__init__.py @@ -9,11 +9,13 @@ from smithy_core.serializers import ShapeSerializer from smithy_core.types import TimestampFormat -from ._private import JSONSettings as _JSONSettings from ._private.deserializers import JSONShapeDeserializer as _JSONShapeDeserializer +from ._private.documents import JSONDocument from ._private.serializers import JSONShapeSerializer as _JSONShapeSerializer +from .settings import JSONSettings __version__: str = importlib.metadata.version("smithy-json") +__all__ = ("JSONCodec", "JSONDocument", "JSONSettings") class JSONCodec(Codec): @@ -25,6 +27,7 @@ def __init__( use_timestamp_format: bool = True, default_timestamp_format: TimestampFormat = TimestampFormat.DATE_TIME, default_namespace: str | None = None, + document_class: type[JSONDocument] = JSONDocument, ) -> None: """Initializes a JSONCodec. @@ -34,11 +37,16 @@ def __init__( `smithy.api#timestampFormat` trait, if present. :param default_timestamp_format: The default timestamp format to use if the `smithy.api#timestampFormat` trait is not enabled or not present. + :param default_namespace: The default namespace to use when determining a + document's discriminator. + :param document_class: The document class to deserialize to. """ - self._settings = _JSONSettings( + self._settings = JSONSettings( use_json_name=use_json_name, use_timestamp_format=use_timestamp_format, default_timestamp_format=default_timestamp_format, + default_namespace=default_namespace, + document_class=document_class, ) @property diff --git a/packages/smithy-json/src/smithy_json/_private/__init__.py b/packages/smithy-json/src/smithy_json/_private/__init__.py index faa62e92..04559d52 100644 --- a/packages/smithy-json/src/smithy_json/_private/__init__.py +++ b/packages/smithy-json/src/smithy_json/_private/__init__.py @@ -1,21 +1,10 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 - -from dataclasses import dataclass from typing import Protocol, runtime_checkable -from smithy_core.types import TimestampFormat - @runtime_checkable class Flushable(Protocol): """A protocol for objects that can be flushed.""" def flush(self) -> None: ... - - -@dataclass -class JSONSettings: - use_json_name: bool = True - use_timestamp_format: bool = True - default_timestamp_format: TimestampFormat = TimestampFormat.DATE_TIME diff --git a/packages/smithy-json/src/smithy_json/_private/deserializers.py b/packages/smithy-json/src/smithy_json/_private/deserializers.py index 8607c85e..bbbb1692 100644 --- a/packages/smithy-json/src/smithy_json/_private/deserializers.py +++ b/packages/smithy-json/src/smithy_json/_private/deserializers.py @@ -18,8 +18,7 @@ from smithy_core.traits import JSONNameTrait, TimestampFormatTrait from smithy_core.types import TimestampFormat -from . import JSONSettings -from .documents import JSONDocument +from ..settings import JSONSettings # TODO: put these type hints in a pyi somewhere. There here because ijson isn't # typed. @@ -89,11 +88,9 @@ def peek(self) -> JSONParseEvent: class JSONShapeDeserializer(ShapeDeserializer): - def __init__( - self, source: BytesReader, *, settings: JSONSettings | None = None - ) -> None: + def __init__(self, source: BytesReader, settings: JSONSettings) -> None: self._stream = BufferedParser(ijson.parse(source)) - self._settings = settings or JSONSettings() + self._settings = settings # A mapping of json name to member name for each shape. Since the deserializer # is shared and we don't know which shapes will be deserialized, this is @@ -158,7 +155,9 @@ def read_string(self, schema: Schema) -> str: def read_document(self, schema: Schema) -> Document: start = next(self._stream) if start.type not in ("start_map", "start_array"): - return JSONDocument(start.value, schema=schema, settings=self._settings) + return self._settings.document_class( + value=start.value, schema=schema, settings=self._settings + ) end_type = "end_map" if start.type == "start_map" else "end_array" builder = cast(TypedObjectBuilder, ObjectBuilder()) @@ -168,7 +167,9 @@ def read_document(self, schema: Schema) -> Document: ).path != start.path or event.type != end_type: builder.event(event.type, event.value) - return JSONDocument(builder.value, schema=schema, settings=self._settings) + return self._settings.document_class( + value=builder.value, schema=schema, settings=self._settings + ) def read_timestamp(self, schema: Schema) -> datetime.datetime: format = self._settings.default_timestamp_format diff --git a/packages/smithy-json/src/smithy_json/_private/documents.py b/packages/smithy-json/src/smithy_json/_private/documents.py index 0fcdd07e..7c013851 100644 --- a/packages/smithy-json/src/smithy_json/_private/documents.py +++ b/packages/smithy-json/src/smithy_json/_private/documents.py @@ -13,23 +13,22 @@ from smithy_core.traits import JSONNameTrait, TimestampFormatTrait from smithy_core.utils import expect_type -from . import JSONSettings +from ..settings import JSONSettings class JSONDocument(Document): - _schema: Schema - _json_names: dict[str, str] + _discriminator: ShapeID | None = None def __init__( self, value: DocumentValue | dict[str, "Document"] | list["Document"], + settings: JSONSettings | None = None, *, schema: Schema = DOCUMENT, - settings: JSONSettings | None = None, ) -> None: super().__init__(value, schema=schema) - self._settings = settings or JSONSettings() - self._json_names = {} + self._settings = settings or JSONSettings(document_class=type(self)) + self._json_names: dict[str, str] = {} if self._settings.use_json_name and schema.shape_type in ( ShapeType.STRUCTURE, @@ -39,12 +38,6 @@ def __init__( if json_name := member_schema.get_trait(JSONNameTrait): self._json_names[json_name.value] = member_name - @property - def discriminator(self) -> ShapeID: - if self._type is ShapeType.MAP: - return ShapeID(self.as_map()["__type"].as_string()) - return super().discriminator - def as_blob(self) -> bytes: return b64decode(expect_type(str, self._value)) diff --git a/packages/smithy-json/src/smithy_json/_private/serializers.py b/packages/smithy-json/src/smithy_json/_private/serializers.py index e3a0051c..c1cd3df7 100644 --- a/packages/smithy-json/src/smithy_json/_private/serializers.py +++ b/packages/smithy-json/src/smithy_json/_private/serializers.py @@ -21,18 +21,17 @@ from smithy_core.traits import JSONNameTrait, TimestampFormatTrait from smithy_core.types import TimestampFormat -from . import Flushable, JSONSettings +from ..settings import JSONSettings +from . import Flushable _INF: float = float("inf") _NEG_INF: float = float("-inf") class JSONShapeSerializer(ShapeSerializer): - def __init__( - self, sink: BytesWriter, *, settings: JSONSettings | None = None - ) -> None: + def __init__(self, sink: BytesWriter, settings: JSONSettings) -> None: self._stream = StreamingJSONEncoder(sink) - self._settings = settings or JSONSettings() + self._settings = settings def begin_struct( self, schema: "Schema" diff --git a/packages/smithy-json/src/smithy_json/settings.py b/packages/smithy-json/src/smithy_json/settings.py new file mode 100644 index 00000000..6ebde62c --- /dev/null +++ b/packages/smithy-json/src/smithy_json/settings.py @@ -0,0 +1,31 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass +from typing import TYPE_CHECKING + +from smithy_core.types import TimestampFormat + +if TYPE_CHECKING: + from ._private.documents import JSONDocument + + +@dataclass(slots=True) +class JSONSettings: + """Settings for the JSON codec.""" + + document_class: type["JSONDocument"] + """The document class to deserialize to.""" + + use_json_name: bool = True + """Whether the codec should use `smithy.api#jsonName` trait, if present.""" + + use_timestamp_format: bool = True + """Whether the codec should use the `smithy.api#timestampFormat` trait, if + present.""" + + default_timestamp_format: TimestampFormat = TimestampFormat.DATE_TIME + """The default timestamp format to use if the `smithy.api#timestampFormat` trait is + not enabled or not present.""" + + default_namespace: str | None = None + """The default namespace to use when determining a document's discriminator.""" diff --git a/packages/smithy-json/tests/unit/test_deserializers.py b/packages/smithy-json/tests/unit/test_deserializers.py index cee2fcd2..f67e9c38 100644 --- a/packages/smithy-json/tests/unit/test_deserializers.py +++ b/packages/smithy-json/tests/unit/test_deserializers.py @@ -9,12 +9,13 @@ BIG_DECIMAL, BLOB, BOOLEAN, + DOCUMENT, FLOAT, INTEGER, STRING, TIMESTAMP, ) -from smithy_json import JSONCodec +from smithy_json import JSONCodec, JSONDocument from . import ( JSON_SERDE_CASES, @@ -88,3 +89,13 @@ def _read_optional_map(k: str, d: ShapeDeserializer): assert actual_value == expected_value else: assert actual == expected + + +class CustomDocument(JSONDocument): + pass + + +def test_uses_custom_document() -> None: + codec = JSONCodec(document_class=CustomDocument) + actual = codec.create_deserializer(b'{"foo": "bar"}').read_document(DOCUMENT) + assert isinstance(actual, CustomDocument) From 9ee1229ae2e5dbd7a06f397a29af9f2c616ca3a3 Mon Sep 17 00:00:00 2001 From: JordonPhillips Date: Fri, 9 May 2025 15:32:03 +0200 Subject: [PATCH 4/5] Handle AWS error codes in documents --- .../src/smithy_aws_core/aio/protocols.py | 29 ++++--- .../src/smithy_aws_core/utils.py | 32 ++++++++ .../tests/unit/aio/test_protocols.py | 57 +++++++++++++- .../smithy-aws-core/tests/unit/test_utils.py | 78 +++++++++++++++++++ 4 files changed, 186 insertions(+), 10 deletions(-) create mode 100644 packages/smithy-aws-core/src/smithy_aws_core/utils.py create mode 100644 packages/smithy-aws-core/tests/unit/test_utils.py diff --git a/packages/smithy-aws-core/src/smithy_aws_core/aio/protocols.py b/packages/smithy-aws-core/src/smithy_aws_core/aio/protocols.py index 51904697..0b4d56c3 100644 --- a/packages/smithy-aws-core/src/smithy_aws_core/aio/protocols.py +++ b/packages/smithy-aws-core/src/smithy_aws_core/aio/protocols.py @@ -1,13 +1,15 @@ from typing import Any, Final from smithy_core.codecs import Codec +from smithy_core.exceptions import DiscriminatorError from smithy_core.schemas import APIOperation -from smithy_core.shapes import ShapeID +from smithy_core.shapes import ShapeID, ShapeType from smithy_http.aio.interfaces import HTTPErrorIdentifier, HTTPResponse from smithy_http.aio.protocols import HttpBindingClientProtocol -from smithy_json import JSONCodec +from smithy_json import JSONCodec, JSONDocument from ..traits import RestJson1Trait +from ..utils import parse_document_discriminator, parse_error_code class AWSErrorIdentifier(HTTPErrorIdentifier): @@ -24,20 +26,29 @@ def identify( error_field = response.fields[self._HEADER_KEY] code = error_field.values[0] if len(error_field.values) > 0 else None - if not code: - return None + if code is not None: + return parse_error_code(code, operation.schema.id.namespace) + return None + - code = code.split(":")[0] - if "#" in code: - return ShapeID(code) - return ShapeID.from_parts(name=code, namespace=operation.schema.id.namespace) +class AWSJSONDocument(JSONDocument): + @property + def discriminator(self) -> ShapeID: + if self.shape_type is ShapeType.STRUCTURE: + return self._schema.id + parsed = parse_document_discriminator(self, self._settings.default_namespace) + if parsed is None: + raise DiscriminatorError( + f"Unable to parse discriminator for {self.shape_type} docuemnt." + ) + return parsed class RestJsonClientProtocol(HttpBindingClientProtocol): """An implementation of the aws.protocols#restJson1 protocol.""" _id: Final = RestJson1Trait.id - _codec: Final = JSONCodec() + _codec: Final = JSONCodec(document_class=AWSJSONDocument) _contentType: Final = "application/json" _error_identifier: Final = AWSErrorIdentifier() diff --git a/packages/smithy-aws-core/src/smithy_aws_core/utils.py b/packages/smithy-aws-core/src/smithy_aws_core/utils.py new file mode 100644 index 00000000..940160e0 --- /dev/null +++ b/packages/smithy-aws-core/src/smithy_aws_core/utils.py @@ -0,0 +1,32 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +from smithy_core.documents import Document +from smithy_core.shapes import ShapeID, ShapeType + + +def parse_document_discriminator( + document: Document, default_namespace: str | None +) -> ShapeID | None: + if document.shape_type is ShapeType.MAP: + map_document = document.as_map() + code = map_document.get("__type") + if code is None: + code = map_document.get("code") + if code is not None and code.shape_type is ShapeType.STRING: + return parse_error_code(code.as_string(), default_namespace) + + return None + + +def parse_error_code(code: str, default_namespace: str | None) -> ShapeID | None: + if not code: + return None + + code = code.split(":")[0] + if "#" in code: + return ShapeID(code) + + if not code or not default_namespace: + return None + + return ShapeID.from_parts(name=code, namespace=default_namespace) diff --git a/packages/smithy-aws-core/tests/unit/aio/test_protocols.py b/packages/smithy-aws-core/tests/unit/aio/test_protocols.py index 82cf7d1e..7b767a08 100644 --- a/packages/smithy-aws-core/tests/unit/aio/test_protocols.py +++ b/packages/smithy-aws-core/tests/unit/aio/test_protocols.py @@ -4,11 +4,13 @@ from unittest.mock import Mock import pytest -from smithy_aws_core.aio.protocols import AWSErrorIdentifier +from smithy_aws_core.aio.protocols import AWSErrorIdentifier, AWSJSONDocument +from smithy_core.exceptions import DiscriminatorError from smithy_core.schemas import APIOperation, Schema from smithy_core.shapes import ShapeID, ShapeType from smithy_http import Fields, tuples_to_fields from smithy_http.aio import HTTPResponse +from smithy_json import JSONSettings @pytest.mark.parametrize( @@ -24,6 +26,7 @@ "com.test#FooError", ), ("", None), + (":", None), (None, None), ], ) @@ -42,3 +45,55 @@ def test_aws_error_identifier(header: str | None, expected: ShapeID | None) -> N actual = error_identifier.identify(operation=operation, response=http_response) assert actual == expected + + +@pytest.mark.parametrize( + "document, expected", + [ + ({"__type": "FooError"}, "com.test#FooError"), + ({"__type": "com.test#FooError"}, "com.test#FooError"), + ( + { + "__type": "FooError:http://internal.amazon.com/coral/com.amazon.coral.validate/" + }, + "com.test#FooError", + ), + ( + { + "__type": "com.test#FooError:http://internal.amazon.com/coral/com.amazon.coral.validate" + }, + "com.test#FooError", + ), + ({"code": "FooError"}, "com.test#FooError"), + ({"code": "com.test#FooError"}, "com.test#FooError"), + ( + { + "code": "FooError:http://internal.amazon.com/coral/com.amazon.coral.validate/" + }, + "com.test#FooError", + ), + ( + { + "code": "com.test#FooError:http://internal.amazon.com/coral/com.amazon.coral.validate" + }, + "com.test#FooError", + ), + ({"__type": "FooError", "code": "BarError"}, "com.test#FooError"), + ("FooError", None), + ({"__type": None}, None), + ({"__type": ""}, None), + ({"__type": ":"}, None), + ], +) +def test_aws_json_document_discriminator( + document: dict[str, str], expected: ShapeID | None +) -> None: + settings = JSONSettings( + document_class=AWSJSONDocument, default_namespace="com.test" + ) + if expected is None: + with pytest.raises(DiscriminatorError): + AWSJSONDocument(document, settings=settings).discriminator + else: + discriminator = AWSJSONDocument(document, settings=settings).discriminator + assert discriminator == expected diff --git a/packages/smithy-aws-core/tests/unit/test_utils.py b/packages/smithy-aws-core/tests/unit/test_utils.py new file mode 100644 index 00000000..6927a2fc --- /dev/null +++ b/packages/smithy-aws-core/tests/unit/test_utils.py @@ -0,0 +1,78 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +import pytest +from smithy_aws_core.utils import parse_document_discriminator, parse_error_code +from smithy_core.documents import Document +from smithy_core.shapes import ShapeID + + +@pytest.mark.parametrize( + "document, expected", + [ + ({"__type": "FooError"}, "com.test#FooError"), + ({"__type": "com.test#FooError"}, "com.test#FooError"), + ( + { + "__type": "FooError:http://internal.amazon.com/coral/com.amazon.coral.validate/" + }, + "com.test#FooError", + ), + ( + { + "__type": "com.test#FooError:http://internal.amazon.com/coral/com.amazon.coral.validate" + }, + "com.test#FooError", + ), + ({"code": "FooError"}, "com.test#FooError"), + ({"code": "com.test#FooError"}, "com.test#FooError"), + ( + { + "code": "FooError:http://internal.amazon.com/coral/com.amazon.coral.validate/" + }, + "com.test#FooError", + ), + ( + { + "code": "com.test#FooError:http://internal.amazon.com/coral/com.amazon.coral.validate" + }, + "com.test#FooError", + ), + ({"__type": "FooError", "code": "BarError"}, "com.test#FooError"), + ("FooError", None), + ({"__type": None}, None), + ({"__type": ""}, None), + ({"__type": ":"}, None), + ], +) +def test_aws_json_document_discriminator( + document: dict[str, str], expected: ShapeID | None +) -> None: + actual = parse_document_discriminator(Document(document), "com.test") + assert actual == expected + + +@pytest.mark.parametrize( + "code, expected", + [ + ("FooError", "com.test#FooError"), + ( + "FooError:http://internal.amazon.com/coral/com.amazon.coral.validate/", + "com.test#FooError", + ), + ( + "com.test#FooError:http://internal.amazon.com/coral/com.amazon.coral.validate", + "com.test#FooError", + ), + ("", None), + (":", None), + ], +) +def test_parse_error_code(code: str, expected: ShapeID | None) -> None: + actual = parse_error_code(code, "com.test") + assert actual == expected + + +def test_parse_error_code_without_default_namespace() -> None: + actual = parse_error_code("FooError", None) + assert actual is None From cc0b232bcf05797e6a46382943dfc455b9af646a Mon Sep 17 00:00:00 2001 From: Jordon Phillips Date: Mon, 12 May 2025 14:54:41 +0200 Subject: [PATCH 5/5] Fix typo in protocols docs Co-authored-by: Nate Prewitt --- packages/smithy-aws-core/src/smithy_aws_core/aio/protocols.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/smithy-aws-core/src/smithy_aws_core/aio/protocols.py b/packages/smithy-aws-core/src/smithy_aws_core/aio/protocols.py index 0b4d56c3..ba0e8be8 100644 --- a/packages/smithy-aws-core/src/smithy_aws_core/aio/protocols.py +++ b/packages/smithy-aws-core/src/smithy_aws_core/aio/protocols.py @@ -39,7 +39,7 @@ def discriminator(self) -> ShapeID: parsed = parse_document_discriminator(self, self._settings.default_namespace) if parsed is None: raise DiscriminatorError( - f"Unable to parse discriminator for {self.shape_type} docuemnt." + f"Unable to parse discriminator for {self.shape_type} document." ) return parsed