Skip to content

Commit 97bded0

Browse files
authored
Handle empty context properly for Avro deserializer (#2035)
* update * lint * address feedback
1 parent 3e30757 commit 97bded0

File tree

4 files changed

+62
-21
lines changed

4 files changed

+62
-21
lines changed

src/confluent_kafka/schema_registry/__init__.py

Lines changed: 28 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
SchemaReference,
3737
ServerConfig
3838
)
39-
from ..serialization import SerializationError, MessageField
39+
from ..serialization import SerializationError, MessageField, SerializationContext
4040

4141
_KEY_SCHEMA_ID = "__key_schema_id"
4242
_VALUE_SCHEMA_ID = "__value_schema_id"
@@ -128,7 +128,7 @@ def reference_subject_name_strategy(ctx, schema_ref: SchemaReference) -> Optiona
128128
return schema_ref.name if schema_ref is not None else None
129129

130130

131-
def header_schema_id_serializer(payload: bytes, ctx, schema_id) -> bytes:
131+
def header_schema_id_serializer(payload: bytes, ctx: Optional[SerializationContext], schema_id) -> bytes:
132132
"""
133133
Serializes the schema guid into the header.
134134
@@ -141,6 +141,9 @@ def header_schema_id_serializer(payload: bytes, ctx, schema_id) -> bytes:
141141
Returns:
142142
bytes: The payload
143143
"""
144+
if ctx is None:
145+
raise SerializationError("SerializationContext is required for header_schema_id_serializer")
146+
144147
headers = ctx.headers
145148
if headers is None:
146149
raise SerializationError("Missing headers")
@@ -171,7 +174,7 @@ def prefix_schema_id_serializer(payload: bytes, ctx, schema_id) -> bytes:
171174
return schema_id.id_to_bytes() + payload
172175

173176

174-
def dual_schema_id_deserializer(payload: bytes, ctx, schema_id) -> io.BytesIO:
177+
def dual_schema_id_deserializer(payload: bytes, ctx: Optional[SerializationContext], schema_id) -> io.BytesIO:
175178
"""
176179
Deserializes the schema id by first checking the header, then the payload prefix.
177180
@@ -184,22 +187,28 @@ def dual_schema_id_deserializer(payload: bytes, ctx, schema_id) -> io.BytesIO:
184187
Returns:
185188
bytes: The payload
186189
"""
187-
headers = ctx.headers
188-
header_key = _KEY_SCHEMA_ID if ctx.field == MessageField.KEY else _VALUE_SCHEMA_ID
189-
if headers is not None:
190-
header_value = None
191-
if isinstance(headers, list):
192-
# look for header_key in headers
193-
for header in headers:
194-
if header[0] == header_key:
195-
header_value = header[1]
196-
break
197-
elif isinstance(headers, dict):
198-
header_value = headers.get(header_key, None)
199-
if header_value is not None:
200-
schema_id.from_bytes(io.BytesIO(header_value))
201-
return io.BytesIO(payload)
202-
return schema_id.from_bytes(io.BytesIO(payload))
190+
# Look for schema ID in headers
191+
header_value = None
192+
193+
if ctx is not None:
194+
headers = ctx.headers
195+
if headers is not None:
196+
header_key = _KEY_SCHEMA_ID if ctx.field == MessageField.KEY else _VALUE_SCHEMA_ID
197+
if isinstance(headers, list):
198+
# look for header_key in headers
199+
for header in headers:
200+
if header[0] == header_key:
201+
header_value = header[1]
202+
break
203+
elif isinstance(headers, dict):
204+
header_value = headers.get(header_key, None)
205+
206+
# Parse schema ID from determined source and return appropriate payload
207+
if header_value is not None:
208+
schema_id.from_bytes(io.BytesIO(header_value))
209+
return io.BytesIO(payload) # Return full payload when schema ID is in header
210+
else:
211+
return schema_id.from_bytes(io.BytesIO(payload)) # Parse from payload, return remainder
203212

204213

205214
def prefix_schema_id_deserializer(payload: bytes, ctx, schema_id) -> io.BytesIO:

src/confluent_kafka/schema_registry/_async/schema_registry_client.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -763,8 +763,9 @@ async def get_schema_by_guid(
763763
if schema is not None:
764764
return schema
765765

766+
query = {}
766767
if fmt is not None:
767-
query = {'format': fmt}
768+
query['format'] = fmt
768769
response = await self._rest_client.get('schemas/guids/{}'.format(guid), query)
769770

770771
registered_schema = RegisteredSchema.from_dict(response)

src/confluent_kafka/schema_registry/_sync/schema_registry_client.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -763,8 +763,9 @@ def get_schema_by_guid(
763763
if schema is not None:
764764
return schema
765765

766+
query = {}
766767
if fmt is not None:
767-
query = {'format': fmt}
768+
query['format'] = fmt
768769
response = self._rest_client.get('schemas/guids/{}'.format(guid), query)
769770

770771
registered_schema = RegisteredSchema.from_dict(response)

tests/schema_registry/test_schema_id.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,14 @@
1616
# limitations under the License.
1717
#
1818
import io
19+
import pytest
20+
1921
from confluent_kafka.schema_registry.serde import SchemaId
22+
from confluent_kafka.schema_registry import (
23+
dual_schema_id_deserializer,
24+
header_schema_id_serializer,
25+
SerializationError
26+
)
2027

2128

2229
def test_schema_guid():
@@ -71,3 +78,26 @@ def test_schema_id_with_message_indexes():
7178
assert indexes == [1, 2, 3]
7279
output = schema_id.id_to_bytes()
7380
assert output == input
81+
82+
83+
def test_dual_schema_id_deserializer_handles_none_context():
84+
"""
85+
Ensures dual_schema_id_deserializer handles None SerializationContext properly.
86+
"""
87+
schema_id = SchemaId("AVRO")
88+
test_data = b'\x00\x00\x00\x00\x01' # Valid schema ID format
89+
90+
result = dual_schema_id_deserializer(test_data, ctx=None, schema_id=schema_id)
91+
92+
# Verify it returns BytesIO and parsed the schema ID
93+
assert isinstance(result, io.BytesIO)
94+
assert schema_id.id == 1
95+
96+
97+
def test_header_schema_id_serializer_handles_none_context():
98+
"""
99+
Ensures header_schema_id_serializer handles None SerializationContext properly.
100+
"""
101+
# schema_id won't be used since function raises error early when ctx=None
102+
with pytest.raises(SerializationError, match="SerializationContext is required"):
103+
header_schema_id_serializer(b"test_payload", ctx=None, schema_id=None)

0 commit comments

Comments
 (0)