22# SPDX-License-Identifier: Apache-2.0
33from asyncio import iscoroutinefunction
44from base64 import b64encode
5- from collections .abc import Callable , Iterator
5+ from collections .abc import Callable , Iterator , Sized
66from contextlib import contextmanager
77from datetime import datetime
88from decimal import Decimal
99from io import BytesIO
10- from typing import TYPE_CHECKING , Any
10+ from typing import TYPE_CHECKING
1111from urllib .parse import quote as urlquote
1212
1313from smithy_core import URI
14- from smithy_core .aio .types import AsyncBytesReader
14+ from smithy_core .aio .types import AsyncBytesProvider , AsyncBytesReader
1515from smithy_core .codecs import Codec
1616from smithy_core .exceptions import SerializationError
1717from smithy_core .schemas import Schema
@@ -81,7 +81,7 @@ def __init__(
8181
8282 @contextmanager
8383 def begin_struct (self , schema : Schema ) -> Iterator [ShapeSerializer ]:
84- payload : Any
84+ payload : AsyncBytesReader | AsyncBytesProvider
8585 binding_serializer : HTTPRequestBindingSerializer
8686
8787 host_prefix = ""
@@ -93,7 +93,17 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
9393 content_length_required = False
9494
9595 binding_matcher = RequestBindingMatcher (schema )
96- if (payload_member := binding_matcher .payload_member ) is not None :
96+ if binding_matcher .event_stream_member is not None :
97+ payload = AsyncBytesProvider ()
98+ content_type = "application/vnd.amazon.eventstream"
99+ binding_serializer = HTTPRequestBindingSerializer (
100+ SpecificShapeSerializer (),
101+ self ._http_trait .path ,
102+ host_prefix ,
103+ binding_matcher ,
104+ )
105+ yield binding_serializer
106+ elif (payload_member := binding_matcher .payload_member ) is not None :
97107 content_length_required = RequiresLengthTrait in payload_member
98108 if payload_member .shape_type in (
99109 ShapeType .BLOB ,
@@ -115,31 +125,28 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
115125 binding_matcher ,
116126 )
117127 yield binding_serializer
118- payload = payload_serializer .payload or b""
119- try :
120- content_length = len (payload )
121- except TypeError :
122- pass
128+ if isinstance (payload_serializer .payload , Sized ):
129+ content_length = len (payload_serializer .payload )
130+ payload = AsyncBytesReader (payload_serializer .payload or b"" )
123131 else :
124132 if (media_type := payload_member .get_trait (MediaTypeTrait )) is not None :
125133 content_type = media_type .value
126- payload = BytesIO ()
127- payload_serializer = self ._payload_codec .create_serializer (payload )
134+ sync_payload = BytesIO ()
135+ payload_serializer = self ._payload_codec .create_serializer (sync_payload )
128136 binding_serializer = HTTPRequestBindingSerializer (
129137 payload_serializer ,
130138 self ._http_trait .path ,
131139 host_prefix ,
132140 binding_matcher ,
133141 )
134142 yield binding_serializer
135- content_length = payload .tell ()
136- payload .seek (0 )
143+ content_length = sync_payload .tell ()
144+ sync_payload .seek (0 )
145+ payload = AsyncBytesReader (sync_payload )
137146 else :
138- payload = BytesIO ()
139- payload_serializer = self ._payload_codec .create_serializer (payload )
147+ sync_payload = BytesIO ()
148+ payload_serializer = self ._payload_codec .create_serializer (sync_payload )
140149 if binding_matcher .should_write_body (self ._omit_empty_payload ):
141- if binding_matcher .event_stream_member is not None :
142- content_type = "application/vnd.amazon.eventstream"
143150 with payload_serializer .begin_struct (schema ) as body_serializer :
144151 binding_serializer = HTTPRequestBindingSerializer (
145152 body_serializer ,
@@ -148,7 +155,7 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
148155 binding_matcher ,
149156 )
150157 yield binding_serializer
151- content_length = payload .tell ()
158+ content_length = sync_payload .tell ()
152159 else :
153160 content_type = None
154161 content_length = None
@@ -159,7 +166,8 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
159166 binding_matcher ,
160167 )
161168 yield binding_serializer
162- payload .seek (0 )
169+ sync_payload .seek (0 )
170+ payload = AsyncBytesReader (sync_payload )
163171
164172 headers = binding_serializer .header_serializer .headers
165173 if content_type is not None :
@@ -189,11 +197,13 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
189197 ),
190198 ),
191199 fields = fields ,
192- body = AsyncBytesReader ( payload ) ,
200+ body = payload ,
193201 )
194202
195203
196- def _compute_content_length (payload : Any ) -> int | None :
204+ def _compute_content_length (
205+ payload : AsyncBytesReader | AsyncBytesProvider ,
206+ ) -> int | None :
197207 if (tell := getattr (payload , "tell" , None )) is not None and not iscoroutinefunction (
198208 tell
199209 ):
@@ -205,7 +215,9 @@ def _compute_content_length(payload: Any) -> int | None:
205215 return None
206216
207217
208- def _seek (payload : Any , pos : int , whence : int = 0 ) -> None :
218+ def _seek (
219+ payload : AsyncBytesReader | AsyncBytesProvider , pos : int , whence : int = 0
220+ ) -> None :
209221 if (seek := getattr (payload , "seek" , None )) is not None and not iscoroutinefunction (
210222 seek
211223 ):
@@ -278,15 +290,22 @@ def __init__(
278290
279291 @contextmanager
280292 def begin_struct (self , schema : Schema ) -> Iterator [ShapeSerializer ]:
281- payload : Any
293+ payload : AsyncBytesReader | AsyncBytesProvider
282294 binding_serializer : HTTPResponseBindingSerializer
283295
284296 content_type : str | None = self ._payload_codec .media_type
285297 content_length : int | None = None
286298 content_length_required = False
287299
288300 binding_matcher = ResponseBindingMatcher (schema )
289- if (payload_member := binding_matcher .payload_member ) is not None :
301+ if binding_matcher .event_stream_member is not None :
302+ payload = AsyncBytesProvider ()
303+ content_type = "application/vnd.amazon.eventstream"
304+ binding_serializer = HTTPResponseBindingSerializer (
305+ SpecificShapeSerializer (), binding_matcher
306+ )
307+ yield binding_serializer
308+ elif (payload_member := binding_matcher .payload_member ) is not None :
290309 content_length_required = RequiresLengthTrait in payload_member
291310 if payload_member .shape_type in (ShapeType .BLOB , ShapeType .STRING ):
292311 if (media_type := payload_member .get_trait (MediaTypeTrait )) is not None :
@@ -300,25 +319,24 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
300319 payload_serializer , binding_matcher
301320 )
302321 yield binding_serializer
303- payload = payload_serializer .payload or b""
304- try :
305- content_length = len (payload )
306- except TypeError :
307- pass
322+ if isinstance (payload_serializer .payload , Sized ):
323+ content_length = len (payload_serializer .payload )
324+ payload = AsyncBytesReader (payload_serializer .payload or b"" )
308325 else :
309326 if (media_type := payload_member .get_trait (MediaTypeTrait )) is not None :
310327 content_type = media_type .value
311- payload = BytesIO ()
312- payload_serializer = self ._payload_codec .create_serializer (payload )
328+ sync_payload = BytesIO ()
329+ payload_serializer = self ._payload_codec .create_serializer (sync_payload )
313330 binding_serializer = HTTPResponseBindingSerializer (
314331 payload_serializer , binding_matcher
315332 )
316333 yield binding_serializer
317- content_length = payload .tell ()
318- payload .seek (0 )
334+ content_length = sync_payload .tell ()
335+ sync_payload .seek (0 )
336+ payload = AsyncBytesReader (sync_payload )
319337 else :
320- payload = BytesIO ()
321- payload_serializer = self ._payload_codec .create_serializer (payload )
338+ sync_payload = BytesIO ()
339+ payload_serializer = self ._payload_codec .create_serializer (sync_payload )
322340 if binding_matcher .should_write_body (self ._omit_empty_payload ):
323341 if binding_matcher .event_stream_member is not None :
324342 content_type = "application/vnd.amazon.eventstream"
@@ -327,7 +345,7 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
327345 body_serializer , binding_matcher
328346 )
329347 yield binding_serializer
330- content_length = payload .tell ()
348+ content_length = sync_payload .tell ()
331349 else :
332350 content_type = None
333351 content_length = None
@@ -336,7 +354,8 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
336354 binding_matcher ,
337355 )
338356 yield binding_serializer
339- payload .seek (0 )
357+ sync_payload .seek (0 )
358+ payload = AsyncBytesReader (sync_payload )
340359
341360 headers = binding_serializer .header_serializer .headers
342361 if content_type is not None :
@@ -364,7 +383,7 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
364383
365384 self .result = _HTTPResponse (
366385 fields = tuples_to_fields (binding_serializer .header_serializer .headers ),
367- body = AsyncBytesReader ( payload ) ,
386+ body = payload ,
368387 status = status ,
369388 )
370389
0 commit comments