|
19 | 19 | import enum |
20 | 20 | import json |
21 | 21 | import logging |
22 | | -import re |
23 | 22 | from datetime import datetime |
24 | | -from typing import Any, List, Optional, Union |
| 23 | +from typing import Any, List, Optional, Union, cast |
| 24 | + |
| 25 | +import json_repair |
25 | 26 |
|
26 | 27 | from pydantic import ValidationError, validate_call |
27 | 28 |
|
|
36 | 37 | TextChunks, |
37 | 38 | ) |
38 | 39 | from neo4j_graphrag.experimental.pipeline.component import Component |
| 40 | +from neo4j_graphrag.experimental.pipeline.exceptions import InvalidJSONError |
39 | 41 | from neo4j_graphrag.generation.prompts import ERExtractionTemplate, PromptTemplate |
40 | 42 | from neo4j_graphrag.llm import LLMInterface |
41 | 43 |
|
@@ -100,28 +102,15 @@ def balance_curly_braces(json_string: str) -> str: |
100 | 102 | return "".join(fixed_json) |
101 | 103 |
|
102 | 104 |
|
103 | | -def fix_invalid_json(invalid_json_string: str) -> str: |
104 | | - # Fix missing quotes around field names |
105 | | - invalid_json_string = re.sub( |
106 | | - r"([{,]\s*)(\w+)(\s*:)", r'\1"\2"\3', invalid_json_string |
107 | | - ) |
108 | | - |
109 | | - # Fix missing quotes around string values, correctly ignoring null, true, false, and numeric values |
110 | | - invalid_json_string = re.sub( |
111 | | - r"(?<=:\s)(?!(null|true|false|\d+\.?\d*))([a-zA-Z_][a-zA-Z0-9_]*)\s*(?=[,}])", |
112 | | - r'"\2"', |
113 | | - invalid_json_string, |
114 | | - ) |
115 | | - |
116 | | - # Correct the specific issue: remove trailing commas within arrays or objects before closing braces or brackets |
117 | | - invalid_json_string = re.sub(r",\s*(?=[}\]])", "", invalid_json_string) |
| 105 | +def fix_invalid_json(raw_json: str) -> str: |
| 106 | + repaired_json = json_repair.repair_json(raw_json) |
| 107 | + repaired_json = cast(str, repaired_json).strip() |
118 | 108 |
|
119 | | - # Normalize excessive curly braces |
120 | | - invalid_json_string = re.sub(r"{{+", "{", invalid_json_string) |
121 | | - invalid_json_string = re.sub(r"}}+", "}", invalid_json_string) |
122 | | - |
123 | | - # Balance curly braces |
124 | | - return balance_curly_braces(invalid_json_string) |
| 109 | + if repaired_json == '""': |
| 110 | + raise InvalidJSONError("JSON repair resulted in an empty or invalid JSON.") |
| 111 | + if not repaired_json: |
| 112 | + raise InvalidJSONError("JSON repair resulted in an empty string.") |
| 113 | + return repaired_json |
125 | 114 |
|
126 | 115 |
|
127 | 116 | class EntityRelationExtractor(Component, abc.ABC): |
@@ -223,24 +212,18 @@ async def extract_for_chunk( |
223 | 212 | ) |
224 | 213 | llm_result = await self.llm.ainvoke(prompt) |
225 | 214 | try: |
226 | | - result = json.loads(llm_result.content) |
227 | | - except json.JSONDecodeError: |
228 | | - logger.info( |
229 | | - f"LLM response is not valid JSON {llm_result.content} for chunk_index={chunk.index}. Trying to fix it." |
230 | | - ) |
231 | | - fixed_content = fix_invalid_json(llm_result.content) |
232 | | - try: |
233 | | - result = json.loads(fixed_content) |
234 | | - except json.JSONDecodeError as e: |
235 | | - if self.on_error == OnError.RAISE: |
236 | | - raise LLMGenerationError( |
237 | | - f"LLM response is not valid JSON {fixed_content}: {e}" |
238 | | - ) |
239 | | - else: |
240 | | - logger.error( |
241 | | - f"LLM response is not valid JSON {llm_result.content} for chunk_index={chunk.index}" |
242 | | - ) |
243 | | - result = {"nodes": [], "relationships": []} |
| 215 | + llm_generated_json = fix_invalid_json(llm_result.content) |
| 216 | + result = json.loads(llm_generated_json) |
| 217 | + except (json.JSONDecodeError, InvalidJSONError) as e: |
| 218 | + if self.on_error == OnError.RAISE: |
| 219 | + raise LLMGenerationError( |
| 220 | + f"LLM response is not valid JSON {llm_result.content}: {e}" |
| 221 | + ) |
| 222 | + else: |
| 223 | + logger.error( |
| 224 | + f"LLM response is not valid JSON {llm_result.content} for chunk_index={chunk.index}" |
| 225 | + ) |
| 226 | + result = {"nodes": [], "relationships": []} |
244 | 227 | try: |
245 | 228 | chunk_graph = Neo4jGraph(**result) |
246 | 229 | except ValidationError as e: |
|
0 commit comments