Skip to content

Commit 8561a1c

Browse files
JordonPhillipsnateprewitt
authored andcommitted
Use AsyncBytesProvider for event streams
# Conflicts: # packages/smithy-http/src/smithy_http/serializers.py
1 parent f5b9557 commit 8561a1c

File tree

3 files changed

+146
-42
lines changed

3 files changed

+146
-42
lines changed

packages/smithy-aws-core/src/smithy_aws_core/aio/protocols.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -148,9 +148,8 @@ def create_event_publisher[
148148
),
149149
)
150150

151-
# The HTTP body must be an async writeable. The HTTP client bindings are
152-
# responsible for ensuring this is the case. The CRT bindings, for example,
153-
# will set the body to an instance of BufferableByteStream.
151+
# The HTTP body must be an async writeable. The HTTP serializers are responsible
152+
# for ensuring this.
154153
body = request.body
155154
if not isinstance(body, AsyncWriter) or not iscoroutinefunction(body.write):
156155
raise UnsupportedStreamError(

packages/smithy-http/src/smithy_http/serializers.py

Lines changed: 58 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,16 @@
22
# SPDX-License-Identifier: Apache-2.0
33
from asyncio import iscoroutinefunction
44
from base64 import b64encode
5-
from collections.abc import Callable, Iterator
5+
from collections.abc import Callable, Iterator, Sized
66
from contextlib import contextmanager
77
from datetime import datetime
88
from decimal import Decimal
99
from io import BytesIO
10-
from typing import TYPE_CHECKING, Any
10+
from typing import TYPE_CHECKING
1111
from urllib.parse import quote as urlquote
1212

1313
from smithy_core import URI
14-
from smithy_core.aio.types import AsyncBytesReader
14+
from smithy_core.aio.types import AsyncBytesProvider, AsyncBytesReader
1515
from smithy_core.codecs import Codec
1616
from smithy_core.exceptions import SerializationError
1717
from smithy_core.schemas import Schema
@@ -81,7 +81,7 @@ def __init__(
8181

8282
@contextmanager
8383
def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
84-
payload: Any
84+
payload: AsyncBytesReader | AsyncBytesProvider
8585
binding_serializer: HTTPRequestBindingSerializer
8686

8787
host_prefix = ""
@@ -93,7 +93,17 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
9393
content_length_required = False
9494

9595
binding_matcher = RequestBindingMatcher(schema)
96-
if (payload_member := binding_matcher.payload_member) is not None:
96+
if binding_matcher.event_stream_member is not None:
97+
payload = AsyncBytesProvider()
98+
content_type = "application/vnd.amazon.eventstream"
99+
binding_serializer = HTTPRequestBindingSerializer(
100+
SpecificShapeSerializer(),
101+
self._http_trait.path,
102+
host_prefix,
103+
binding_matcher,
104+
)
105+
yield binding_serializer
106+
elif (payload_member := binding_matcher.payload_member) is not None:
97107
content_length_required = RequiresLengthTrait in payload_member
98108
if payload_member.shape_type in (
99109
ShapeType.BLOB,
@@ -115,31 +125,28 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
115125
binding_matcher,
116126
)
117127
yield binding_serializer
118-
payload = payload_serializer.payload or b""
119-
try:
120-
content_length = len(payload)
121-
except TypeError:
122-
pass
128+
if isinstance(payload_serializer.payload, Sized):
129+
content_length = len(payload_serializer.payload)
130+
payload = AsyncBytesReader(payload_serializer.payload or b"")
123131
else:
124132
if (media_type := payload_member.get_trait(MediaTypeTrait)) is not None:
125133
content_type = media_type.value
126-
payload = BytesIO()
127-
payload_serializer = self._payload_codec.create_serializer(payload)
134+
sync_payload = BytesIO()
135+
payload_serializer = self._payload_codec.create_serializer(sync_payload)
128136
binding_serializer = HTTPRequestBindingSerializer(
129137
payload_serializer,
130138
self._http_trait.path,
131139
host_prefix,
132140
binding_matcher,
133141
)
134142
yield binding_serializer
135-
content_length = payload.tell()
136-
payload.seek(0)
143+
content_length = sync_payload.tell()
144+
sync_payload.seek(0)
145+
payload = AsyncBytesReader(sync_payload)
137146
else:
138-
payload = BytesIO()
139-
payload_serializer = self._payload_codec.create_serializer(payload)
147+
sync_payload = BytesIO()
148+
payload_serializer = self._payload_codec.create_serializer(sync_payload)
140149
if binding_matcher.should_write_body(self._omit_empty_payload):
141-
if binding_matcher.event_stream_member is not None:
142-
content_type = "application/vnd.amazon.eventstream"
143150
with payload_serializer.begin_struct(schema) as body_serializer:
144151
binding_serializer = HTTPRequestBindingSerializer(
145152
body_serializer,
@@ -148,7 +155,7 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
148155
binding_matcher,
149156
)
150157
yield binding_serializer
151-
content_length = payload.tell()
158+
content_length = sync_payload.tell()
152159
else:
153160
content_type = None
154161
content_length = None
@@ -159,7 +166,8 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
159166
binding_matcher,
160167
)
161168
yield binding_serializer
162-
payload.seek(0)
169+
sync_payload.seek(0)
170+
payload = AsyncBytesReader(sync_payload)
163171

164172
headers = binding_serializer.header_serializer.headers
165173
if content_type is not None:
@@ -189,11 +197,13 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
189197
),
190198
),
191199
fields=fields,
192-
body=AsyncBytesReader(payload),
200+
body=payload,
193201
)
194202

195203

196-
def _compute_content_length(payload: Any) -> int | None:
204+
def _compute_content_length(
205+
payload: AsyncBytesReader | AsyncBytesProvider,
206+
) -> int | None:
197207
if (tell := getattr(payload, "tell", None)) is not None and not iscoroutinefunction(
198208
tell
199209
):
@@ -205,7 +215,9 @@ def _compute_content_length(payload: Any) -> int | None:
205215
return None
206216

207217

208-
def _seek(payload: Any, pos: int, whence: int = 0) -> None:
218+
def _seek(
219+
payload: AsyncBytesReader | AsyncBytesProvider, pos: int, whence: int = 0
220+
) -> None:
209221
if (seek := getattr(payload, "seek", None)) is not None and not iscoroutinefunction(
210222
seek
211223
):
@@ -278,15 +290,22 @@ def __init__(
278290

279291
@contextmanager
280292
def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
281-
payload: Any
293+
payload: AsyncBytesReader | AsyncBytesProvider
282294
binding_serializer: HTTPResponseBindingSerializer
283295

284296
content_type: str | None = self._payload_codec.media_type
285297
content_length: int | None = None
286298
content_length_required = False
287299

288300
binding_matcher = ResponseBindingMatcher(schema)
289-
if (payload_member := binding_matcher.payload_member) is not None:
301+
if binding_matcher.event_stream_member is not None:
302+
payload = AsyncBytesProvider()
303+
content_type = "application/vnd.amazon.eventstream"
304+
binding_serializer = HTTPResponseBindingSerializer(
305+
SpecificShapeSerializer(), binding_matcher
306+
)
307+
yield binding_serializer
308+
elif (payload_member := binding_matcher.payload_member) is not None:
290309
content_length_required = RequiresLengthTrait in payload_member
291310
if payload_member.shape_type in (ShapeType.BLOB, ShapeType.STRING):
292311
if (media_type := payload_member.get_trait(MediaTypeTrait)) is not None:
@@ -300,25 +319,24 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
300319
payload_serializer, binding_matcher
301320
)
302321
yield binding_serializer
303-
payload = payload_serializer.payload or b""
304-
try:
305-
content_length = len(payload)
306-
except TypeError:
307-
pass
322+
if isinstance(payload_serializer.payload, Sized):
323+
content_length = len(payload_serializer.payload)
324+
payload = AsyncBytesReader(payload_serializer.payload or b"")
308325
else:
309326
if (media_type := payload_member.get_trait(MediaTypeTrait)) is not None:
310327
content_type = media_type.value
311-
payload = BytesIO()
312-
payload_serializer = self._payload_codec.create_serializer(payload)
328+
sync_payload = BytesIO()
329+
payload_serializer = self._payload_codec.create_serializer(sync_payload)
313330
binding_serializer = HTTPResponseBindingSerializer(
314331
payload_serializer, binding_matcher
315332
)
316333
yield binding_serializer
317-
content_length = payload.tell()
318-
payload.seek(0)
334+
content_length = sync_payload.tell()
335+
sync_payload.seek(0)
336+
payload = AsyncBytesReader(sync_payload)
319337
else:
320-
payload = BytesIO()
321-
payload_serializer = self._payload_codec.create_serializer(payload)
338+
sync_payload = BytesIO()
339+
payload_serializer = self._payload_codec.create_serializer(sync_payload)
322340
if binding_matcher.should_write_body(self._omit_empty_payload):
323341
if binding_matcher.event_stream_member is not None:
324342
content_type = "application/vnd.amazon.eventstream"
@@ -327,7 +345,7 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
327345
body_serializer, binding_matcher
328346
)
329347
yield binding_serializer
330-
content_length = payload.tell()
348+
content_length = sync_payload.tell()
331349
else:
332350
content_type = None
333351
content_length = None
@@ -336,7 +354,8 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
336354
binding_matcher,
337355
)
338356
yield binding_serializer
339-
payload.seek(0)
357+
sync_payload.seek(0)
358+
payload = AsyncBytesReader(sync_payload)
340359

341360
headers = binding_serializer.header_serializer.headers
342361
if content_type is not None:
@@ -364,7 +383,7 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
364383

365384
self.result = _HTTPResponse(
366385
fields=tuples_to_fields(binding_serializer.header_serializer.headers),
367-
body=AsyncBytesReader(payload),
386+
body=payload,
368387
status=status,
369388
)
370389

packages/smithy-http/tests/unit/test_serializers.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -721,6 +721,54 @@ def _consumer(schema: Schema, de: ShapeDeserializer) -> None:
721721
return cls(**kwargs)
722722

723723

724+
@dataclass
725+
class HTTPEventStreamPayload:
726+
header: str | None = None
727+
728+
EVENT_SCHEMA: ClassVar[Schema] = Schema.collection(
729+
id=ShapeID("com.smithy#Event"), members={"message": {"target": STRING}}
730+
)
731+
EVENT_STREAM_SCHEMA: ClassVar[Schema] = Schema.collection(
732+
id=ShapeID("com.smithy#EventStream"),
733+
shape_type=ShapeType.UNION,
734+
traits=[StreamingTrait()],
735+
members={"messageEvent": {"target": EVENT_SCHEMA}},
736+
)
737+
ID: ClassVar[ShapeID] = ShapeID("com.smithy#HTTPEventStreamPayload")
738+
SCHEMA: ClassVar[Schema] = Schema.collection(
739+
id=ID,
740+
members={
741+
"header": {
742+
"target": STRING,
743+
"traits": [HTTPHeaderTrait("header")],
744+
},
745+
"stream": {"target": EVENT_STREAM_SCHEMA, "traits": [HTTPPayloadTrait()]},
746+
},
747+
)
748+
749+
def serialize(self, serializer: ShapeSerializer) -> None:
750+
with serializer.begin_struct(self.SCHEMA) as s:
751+
self.serialize_members(s)
752+
753+
def serialize_members(self, serializer: ShapeSerializer) -> None:
754+
if self.header is not None:
755+
serializer.write_string(self.SCHEMA.members["header"], self.header)
756+
757+
@classmethod
758+
def deserialize(cls, deserializer: ShapeDeserializer) -> Self:
759+
kwargs: dict[str, Any] = {}
760+
761+
def _consumer(schema: Schema, de: ShapeDeserializer) -> None:
762+
match schema.expect_member_index():
763+
case 0:
764+
kwargs["header"] = de.read_string(cls.SCHEMA.members["header"])
765+
case _:
766+
raise Exception(f"Unexpected schema: {schema}")
767+
768+
deserializer.read_struct(schema=cls.SCHEMA, consumer=_consumer)
769+
return cls(**kwargs)
770+
771+
724772
@dataclass
725773
class HTTPStringLabel:
726774
label: str
@@ -1919,3 +1967,41 @@ async def test_deserialize_http_response_with_async_stream() -> None:
19191967
)
19201968
actual = HTTPStreamingPayload.deserialize(deserializer)
19211969
assert actual == HTTPStreamingPayload(stream)
1970+
1971+
1972+
async def test_serialize_request_event_stream_creates_writeable_body() -> None:
1973+
serializer = HTTPRequestSerializer(
1974+
payload_codec=JSONCodec(),
1975+
http_trait=HTTPTrait({"method": "POST", "code": 200, "uri": "/"}),
1976+
)
1977+
1978+
request = HTTPEventStreamPayload(header="foo")
1979+
request.serialize(serializer)
1980+
1981+
actual = serializer.result
1982+
assert actual is not None
1983+
1984+
body_write = getattr(actual.body, "write", None)
1985+
assert body_write is not None and iscoroutinefunction(body_write)
1986+
1987+
assert "header" in actual.fields
1988+
assert actual.fields["header"].as_string() == "foo"
1989+
1990+
1991+
async def test_serialize_response_event_stream_creates_writeable_body() -> None:
1992+
serializer = HTTPResponseSerializer(
1993+
payload_codec=JSONCodec(),
1994+
http_trait=HTTPTrait({"method": "POST", "code": 200, "uri": "/"}),
1995+
)
1996+
1997+
response = HTTPEventStreamPayload(header="foo")
1998+
response.serialize(serializer)
1999+
2000+
actual = serializer.result
2001+
assert actual is not None
2002+
2003+
body_write = getattr(actual.body, "write", None)
2004+
assert body_write is not None and iscoroutinefunction(body_write)
2005+
2006+
assert "header" in actual.fields
2007+
assert actual.fields["header"].as_string() == "foo"

0 commit comments

Comments
 (0)