Skip to content

Commit 09a65fc

Browse files
authored
Fix error due to wrong property type (#412)
* Fix error due to wrong property type * Ruff * Mypy * Remove unused parameters * Update CHANGELOG
1 parent 83002a2 commit 09a65fc

File tree

4 files changed

+78
-1
lines changed

4 files changed

+78
-1
lines changed

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,11 @@
22

33
## Next
44

5+
### Fixed
6+
7+
- Fixed an edge case where the LLM could output a property of type 'map', which caused errors during import because a map is not a valid property type in Neo4j. Such values are now serialized to JSON strings.
8+
9+
510
## 1.9.1
611

712
### Fixed

src/neo4j_graphrag/experimental/components/graph_pruning.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515
import enum
16+
import json
1617
import logging
1718
from typing import Optional, Any, TypeVar, Generic, Union
1819

@@ -391,11 +392,15 @@ def _enforce_properties(
391392
) -> dict[str, Any]:
392393
"""
393394
Enforce properties:
395+
- Ensure property type: for now, just prevent having invalid property types (e.g. map)
394396
- Filter out those that are not in schema (i.e., valid properties) if allowed properties is False.
395397
- Check that all required properties are present and not null.
396398
"""
397-
filtered_properties = self._filter_properties(
399+
type_safe_properties = self._ensure_property_types(
398400
item.properties,
401+
)
402+
filtered_properties = self._filter_properties(
403+
type_safe_properties,
399404
schema_item.properties,
400405
schema_item.additional_properties,
401406
item.token, # label or type
@@ -453,3 +458,19 @@ def _check_required_properties(
453458
if filtered_properties.get(req_prop) is None:
454459
missing_required_properties.append(req_prop)
455460
return missing_required_properties
461+
462+
def _ensure_property_types(
    self,
    filtered_properties: dict[str, Any],
) -> dict[str, Any]:
    """Return a copy of the properties whose values are safe to insert.

    Maps are not a valid Neo4j property type, and property lists may only
    hold primitive values. Any value that is a ``dict``, or a list/tuple
    containing a ``dict``, list or tuple, is therefore replaced by its JSON
    string so the data is preserved without raising an error on insert.
    All other values are passed through unchanged.

    Args:
        filtered_properties: Property names mapped to their raw values.

    Returns:
        A new dict with the same keys; the input dict is not mutated.
    """

    def _needs_serialization(value: Any) -> bool:
        # A map at the top level can never be stored as a property.
        if isinstance(value, dict):
            return True
        # A list is only storable when it is flat, i.e. it holds no
        # maps and no nested collections.
        if isinstance(value, (list, tuple)):
            return any(isinstance(item, (dict, list, tuple)) for item in value)
        return False

    # default=str keeps values json cannot encode natively (e.g. dates)
    # instead of failing on them.
    # This is where the types of other properties could be checked,
    # but keep it simple for now.
    return {
        prop_name: (
            json.dumps(prop_value, default=str)
            if _needs_serialization(prop_value)
            else prop_value
        )
        for prop_name, prop_value in filtered_properties.items()
    }

src/neo4j_graphrag/experimental/components/schema.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,12 @@ def validate_additional_properties(self) -> Self:
111111
)
112112
return self
113113

114+
def property_type_from_name(self, name: str) -> Optional[PropertyType]:
    """Return the first entry of ``self.properties`` whose ``name`` equals *name*.

    Returns ``None`` when no entry matches.
    """
    return next(
        (candidate for candidate in self.properties if candidate.name == name),
        None,
    )
119+
114120

115121
class RelationshipType(BaseModel):
116122
"""
@@ -141,6 +147,12 @@ def validate_additional_properties(self) -> Self:
141147
)
142148
return self
143149

150+
def property_type_from_name(self, name: str) -> Optional[PropertyType]:
    """Look up a declared property by its name.

    Scans ``self.properties`` in order and returns the first entry whose
    ``name`` attribute equals *name*, or ``None`` if there is none.
    """
    for candidate in self.properties:
        if candidate.name != name:
            continue
        return candidate
    return None
155+
144156

145157
class GraphSchema(DataModel):
146158
"""This model represents the expected

tests/unit/experimental/components/test_graph_pruning.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# limitations under the License.
1515
from __future__ import annotations
1616

17+
import datetime
1718
from typing import Any, Optional
1819
from unittest.mock import ANY, Mock, patch
1920

@@ -101,6 +102,44 @@ def test_graph_pruning_filter_properties(
101102
assert filtered_properties == expected_filtered_properties
102103

103104

105+
@pytest.mark.parametrize(
    "properties, expected_filtered_properties",
    [
        # Only primitive values: nothing has to be converted.
        (
            {"name": "John Does", "age": 25, "is_active": True},
            {"name": "John Does", "age": 25, "is_active": True},
        ),
        # A map value must come back as its JSON string form.
        (
            {"age": {"dob": datetime.date(2000, 1, 1), "age_in_2025": 25}},
            {"age": '{"dob": "2000-01-01", "age_in_2025": 25}'},
        ),
    ],
)
def test_graph_pruning_ensure_property_type(
    properties: dict[str, Any],
    expected_filtered_properties: dict[str, Any],
) -> None:
    """Check that ``_ensure_property_types`` returns the expected properties."""
    result = GraphPruning()._ensure_property_types(properties)
    assert result == expected_filtered_properties
141+
142+
104143
@pytest.fixture(scope="module")
105144
def node_type_no_properties() -> NodeType:
106145
return NodeType(label="Person")

0 commit comments

Comments
 (0)