Skip to content

Commit 791b7b7

Browse files
Address schema serialization bug (#136)
After more testing, we found a small bug in schema serialization involving Enums and Pydantic. This PR fixes the bug and expands the test coverage.
1 parent 66ef8b8 commit 791b7b7

File tree

2 files changed

+34
-33
lines changed

2 files changed

+34
-33
lines changed

redisvl/schema/schema.py

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,28 @@
1414
SCHEMA_VERSION = "0.1.0"
1515

1616

17+
def custom_dict(model: BaseModel) -> Dict[str, Any]:
    """Serialize a Pydantic model to a plain dictionary.

    The model is dumped with ``model.dict(exclude_none=True)`` and the result
    is walked recursively through nested dicts and lists. Every ``Enum``
    member found is replaced by its value: string values are lower-cased,
    while non-string values (e.g. integers) are returned unchanged rather
    than raising ``AttributeError`` on ``.lower()``.

    Args:
        model: The Pydantic model to serialize.

    Returns:
        Dict[str, Any]: The serialized model with Enum members replaced by
            their values.
    """

    def serialize_item(item):
        if isinstance(item, Enum):
            value = item.value
            # Only strings can be lower-cased; pass other value types through
            # untouched so int/float-valued Enums do not crash serialization.
            return value.lower() if isinstance(value, str) else value
        if isinstance(item, dict):
            return {key: serialize_item(value) for key, value in item.items()}
        if isinstance(item, list):
            return [serialize_item(element) for element in item]
        return item

    # The dump is itself a dict, so one recursive call covers every key.
    return serialize_item(model.dict(exclude_none=True))
37+
38+
1739
class StorageType(Enum):
1840
"""
1941
Enumeration for the storage types supported in Redis.
@@ -63,14 +85,6 @@ class IndexInfo(BaseModel):
6385
storage_type: StorageType = StorageType.HASH
6486
"""The storage type used in Redis (e.g., 'hash' or 'json')."""
6587

66-
def dict(self, *args, **kwargs) -> Dict[str, Any]:
67-
return {
68-
"name": self.name,
69-
"prefix": self.prefix,
70-
"key_separator": self.key_separator,
71-
"storage_type": self.storage_type.value,
72-
}
73-
7488

7589
class IndexSchema(BaseModel):
7690
"""A schema definition for a search index in Redis, used in RedisVL for
@@ -428,12 +442,13 @@ def generate_fields(
428442
return fields
429443

430444
def to_dict(self) -> Dict[str, Any]:
431-
"""Convert the index schema to a dictionary.
445+
"""Serialize the index schema model to a dictionary, handling Enums
446+
and other special cases properly.
432447
433448
Returns:
434449
Dict[str, Any]: The index schema as a dictionary.
435450
"""
436-
dict_schema = self.dict(exclude_none=True)
451+
dict_schema = custom_dict(self)
437452
# cast fields back to a pure list
438453
dict_schema["fields"] = [
439454
field for field_name, field in dict_schema["fields"].items()

tests/unit/test_schema.py

Lines changed: 9 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33

44
import pytest
55

6-
from redisvl.schema.fields import NumericField, TagField, TextField
7-
from redisvl.schema.schema import IndexSchema, StorageType
6+
from redisvl.schema.fields import TagField, TextField
7+
from redisvl.schema.schema import IndexSchema, StorageType, custom_dict
88

99

1010
def get_base_path():
@@ -16,6 +16,12 @@ def create_sample_index_schema():
1616
sample_fields = [
1717
{"name": "example_text", "type": "text", "attrs": {"sortable": False}},
1818
{"name": "example_numeric", "type": "numeric", "attrs": {"sortable": True}},
19+
{"name": "example_tag", "type": "tag", "attrs": {"sortable": True}},
20+
{
21+
"name": "example_vector",
22+
"type": "vector",
23+
"attrs": {"dims": 1024, "algorithm": "flat"},
24+
},
1925
]
2026
return IndexSchema.from_dict({"index": {"name": "test"}, "fields": sample_fields})
2127

@@ -89,26 +95,6 @@ def test_remove_field():
8995
assert "example_text" not in index_schema.field_names
9096

9197

92-
def test_schema_compare():
93-
"""Test schema comparisons."""
94-
schema_1 = IndexSchema.from_dict({"index": {"name": "test"}})
95-
# manually add the same fields as the helper method provides below
96-
schema_1.add_fields(
97-
[
98-
{"name": "example_text", "type": "text", "attrs": {"sortable": False}},
99-
{"name": "example_numeric", "type": "numeric", "attrs": {"sortable": True}},
100-
]
101-
)
102-
103-
assert "example_text" in schema_1.fields
104-
assert "example_numeric" in schema_1.fields
105-
106-
schema_2 = create_sample_index_schema()
107-
assert schema_1.fields == schema_2.fields
108-
assert schema_1.index.name == schema_2.index.name
109-
assert schema_1.to_dict() == schema_2.to_dict()
110-
111-
11298
def test_generate_fields():
11399
"""Test field generation."""
114100
sample = {"name": "John", "age": 30, "tags": ["test", "test2"]}
@@ -126,7 +112,7 @@ def test_to_dict():
126112
index_dict = index_schema.to_dict()
127113
assert index_dict["index"]["name"] == "test"
128114
assert isinstance(index_dict["fields"], list)
129-
assert len(index_dict["fields"]) == 2 == len(index_schema.fields)
115+
assert len(index_dict["fields"]) == 4 == len(index_schema.fields)
130116

131117

132118
def test_from_dict():

0 commit comments

Comments
 (0)