@@ -346,13 +346,11 @@ async def run(
346346 return graph
347347
348348 def validate_chunk (
349- self ,
350- chunk_graph : Neo4jGraph ,
351- schema : SchemaConfig
349+ self , chunk_graph : Neo4jGraph , schema : SchemaConfig
352350 ) -> Neo4jGraph :
353351 """
354- Perform validation after entity and relation extraction:
355- - Enforce schema if schema enforcement mode is on and schema is provided
352+ Perform validation after entity and relation extraction:
353+ - Enforce schema if schema enforcement mode is on and schema is provided
356354 """
357355 if self .enforce_schema != SchemaEnforcementMode .NONE :
358356 if not schema or not schema .entities : # schema is not provided
@@ -365,9 +363,9 @@ def validate_chunk(
365363 return chunk_graph
366364
367365 def _clean_graph (
368- self ,
369- graph : Neo4jGraph ,
370- schema : SchemaConfig ,
366+ self ,
367+ graph : Neo4jGraph ,
368+ schema : SchemaConfig ,
371369 ) -> Neo4jGraph :
372370 """
373371 Verify that the graph conforms to the provided schema.
@@ -389,17 +387,15 @@ def _clean_graph(
389387 return Neo4jGraph (nodes = filtered_nodes , relationships = filtered_rels )
390388
391389 def _enforce_nodes (
392- self ,
393- extracted_nodes : List [Neo4jNode ],
394- schema : SchemaConfig
390+ self , extracted_nodes : List [Neo4jNode ], schema : SchemaConfig
395391 ) -> List [Neo4jNode ]:
396392 """
397- Filter extracted nodes to be conformant to the schema.
393+ Filter extracted nodes to be conformant to the schema.
398394
399- Keep only those whose label is in schema.
400- For each valid node, filter out properties not present in the schema.
401- Remove a node if it ends up with no valid properties.
402- """
395+ Keep only those whose label is in schema.
396+ For each valid node, filter out properties not present in the schema.
397+ Remove a node if it ends up with no valid properties.
398+ """
403399 if self .enforce_schema != SchemaEnforcementMode .STRICT :
404400 return extracted_nodes
405401
@@ -424,10 +420,10 @@ def _enforce_nodes(
424420 return valid_nodes
425421
426422 def _enforce_relationships (
427- self ,
428- extracted_relationships : List [Neo4jRelationship ],
429- filtered_nodes : List [Neo4jNode ],
430- schema : SchemaConfig
423+ self ,
424+ extracted_relationships : List [Neo4jRelationship ],
425+ filtered_nodes : List [Neo4jNode ],
426+ schema : SchemaConfig ,
431427 ) -> List [Neo4jRelationship ]:
432428 """
433429 Filter extracted nodes to be conformant to the schema.
@@ -447,12 +443,16 @@ def _enforce_relationships(
447443 potential_schema = schema .potential_schema
448444
449445 for rel in extracted_relationships :
450- schema_relation = schema .relations .get (rel .type )
446+ schema_relation = (
447+ schema .relations .get (rel .type ) if schema .relations else None
448+ )
451449 if not schema_relation :
452450 continue
453451
454- if (rel .start_node_id not in valid_nodes or
455- rel .end_node_id not in valid_nodes ):
452+ if (
453+ rel .start_node_id not in valid_nodes
454+ or rel .end_node_id not in valid_nodes
455+ ):
456456 continue
457457
458458 start_label = valid_nodes [rel .start_node_id ]
@@ -461,8 +461,11 @@ def _enforce_relationships(
461461 tuple_valid = True
462462 if potential_schema :
463463 tuple_valid = (start_label , rel .type , end_label ) in potential_schema
464- reverse_tuple_valid = ((end_label , rel .type , start_label ) in
465- potential_schema )
464+ reverse_tuple_valid = (
465+ end_label ,
466+ rel .type ,
467+ start_label ,
468+ ) in potential_schema
466469
467470 if not tuple_valid and not reverse_tuple_valid :
468471 continue
@@ -483,18 +486,13 @@ def _enforce_relationships(
483486 return valid_rels
484487
485488 def _enforce_properties (
486- self ,
487- properties : Dict [str , Any ],
488- valid_properties : List [Dict [str , Any ]]
489+ self , properties : Dict [str , Any ], valid_properties : List [Dict [str , Any ]]
489490 ) -> Dict [str , Any ]:
490491 """
491492 Filter properties.
492493 Keep only those that exist in schema (i.e., valid properties).
493494 """
494495 valid_prop_names = {prop ["name" ] for prop in valid_properties }
495496 return {
496- key : value
497- for key , value in properties .items ()
498- if key in valid_prop_names
497+ key : value for key , value in properties .items () if key in valid_prop_names
499498 }
500-
0 commit comments