3636 SchemaReference ,
3737 ServerConfig
3838)
39- from ..serialization import SerializationError , MessageField
39+ from ..serialization import SerializationError , MessageField , SerializationContext
4040
4141_KEY_SCHEMA_ID = "__key_schema_id"
4242_VALUE_SCHEMA_ID = "__value_schema_id"
@@ -128,7 +128,7 @@ def reference_subject_name_strategy(ctx, schema_ref: SchemaReference) -> Optiona
128128 return schema_ref .name if schema_ref is not None else None
129129
130130
131- def header_schema_id_serializer (payload : bytes , ctx , schema_id ) -> bytes :
131+ def header_schema_id_serializer (payload : bytes , ctx : Optional [ SerializationContext ] , schema_id ) -> bytes :
132132 """
133133 Serializes the schema guid into the header.
134134
@@ -141,6 +141,9 @@ def header_schema_id_serializer(payload: bytes, ctx, schema_id) -> bytes:
141141 Returns:
142142 bytes: The payload
143143 """
144+ if ctx is None :
145+ raise SerializationError ("SerializationContext is required for header_schema_id_serializer" )
146+
144147 headers = ctx .headers
145148 if headers is None :
146149 raise SerializationError ("Missing headers" )
@@ -171,7 +174,7 @@ def prefix_schema_id_serializer(payload: bytes, ctx, schema_id) -> bytes:
171174 return schema_id .id_to_bytes () + payload
172175
173176
174- def dual_schema_id_deserializer (payload : bytes , ctx , schema_id ) -> io .BytesIO :
177+ def dual_schema_id_deserializer (payload : bytes , ctx : Optional [ SerializationContext ] , schema_id ) -> io .BytesIO :
175178 """
176179 Deserializes the schema id by first checking the header, then the payload prefix.
177180
@@ -184,22 +187,28 @@ def dual_schema_id_deserializer(payload: bytes, ctx, schema_id) -> io.BytesIO:
184187 Returns:
185188 bytes: The payload
186189 """
187- headers = ctx .headers
188- header_key = _KEY_SCHEMA_ID if ctx .field == MessageField .KEY else _VALUE_SCHEMA_ID
189- if headers is not None :
190- header_value = None
191- if isinstance (headers , list ):
192- # look for header_key in headers
193- for header in headers :
194- if header [0 ] == header_key :
195- header_value = header [1 ]
196- break
197- elif isinstance (headers , dict ):
198- header_value = headers .get (header_key , None )
199- if header_value is not None :
200- schema_id .from_bytes (io .BytesIO (header_value ))
201- return io .BytesIO (payload )
202- return schema_id .from_bytes (io .BytesIO (payload ))
190+ # Look for schema ID in headers
191+ header_value = None
192+
193+ if ctx is not None :
194+ headers = ctx .headers
195+ if headers is not None :
196+ header_key = _KEY_SCHEMA_ID if ctx .field == MessageField .KEY else _VALUE_SCHEMA_ID
197+ if isinstance (headers , list ):
198+ # look for header_key in headers
199+ for header in headers :
200+ if header [0 ] == header_key :
201+ header_value = header [1 ]
202+ break
203+ elif isinstance (headers , dict ):
204+ header_value = headers .get (header_key , None )
205+
206+ # Parse schema ID from determined source and return appropriate payload
207+ if header_value is not None :
208+ schema_id .from_bytes (io .BytesIO (header_value ))
209+ return io .BytesIO (payload ) # Return full payload when schema ID is in header
210+ else :
211+ return schema_id .from_bytes (io .BytesIO (payload )) # Parse from payload, return remainder
203212
204213
205214def prefix_schema_id_deserializer (payload : bytes , ctx , schema_id ) -> io .BytesIO :
0 commit comments