|
15 | 15 | from __future__ import annotations |
16 | 16 |
|
17 | 17 | import json |
18 | | -import yaml |
19 | 18 | import logging |
| 19 | +import warnings |
20 | 20 | from typing import Any, Dict, List, Literal, Optional, Tuple, Union, Sequence |
21 | 21 | from pathlib import Path |
22 | 22 |
|
|
42 | 42 | ) |
43 | 43 | from neo4j_graphrag.generation import SchemaExtractionTemplate, PromptTemplate |
44 | 44 | from neo4j_graphrag.llm import LLMInterface |
| 45 | +from neo4j_graphrag.utils.file_handler import FileHandler, FileFormat |
45 | 46 |
|
46 | 47 |
|
47 | 48 | class PropertyType(BaseModel): |
@@ -157,101 +158,68 @@ def node_type_from_label(self, label: str) -> Optional[NodeType]: |
157 | 158 | def relationship_type_from_label(self, label: str) -> Optional[RelationshipType]: |
158 | 159 | return self._relationship_type_index.get(label) |
159 | 160 |
|
160 | | - def store_as_json(self, file_path: str) -> None: |
| 161 | + def save( |
| 162 | + self, |
| 163 | + file_path: Union[str, Path], |
| 164 | + overwrite: bool = False, |
| 165 | + format: Optional[FileFormat] = None, |
| 166 | + ) -> None: |
161 | 167 | """ |
162 | | - Save the schema configuration to a JSON file. |
| 168 | + Save the schema configuration to file. |
163 | 169 |
|
164 | 170 | Args: |
165 | 171 | file_path (str): The path where the schema configuration will be saved. |
| 172 | + overwrite (bool): If set to True, existing file will be overwritten. Default to False. |
| 173 | + format (Optional[FileFormat]): The file format to save the schema configuration into. By default, it is inferred from file_path extension. |
166 | 174 | """ |
167 | | - with open(file_path, "w") as f: |
168 | | - json.dump(self.model_dump(), f, indent=2) |
| 175 | + data = self.model_dump(mode="json") |
| 176 | + file_handler = FileHandler() |
| 177 | + file_handler.write(data, file_path, overwrite=overwrite, format=format) |
169 | 178 |
|
170 | | - def store_as_yaml(self, file_path: str) -> None: |
171 | | - """ |
172 | | - Save the schema configuration to a YAML file. |
| 179 | + def store_as_json( |
| 180 | + self, file_path: Union[str, Path], overwrite: bool = False |
| 181 | + ) -> None: |
| 182 | + warnings.warn( |
| 183 | + "Use .save(..., format=FileFormat.JSON) instead.", DeprecationWarning |
| 184 | + ) |
| 185 | + return self.save(file_path, overwrite=overwrite, format=FileFormat.JSON) |
173 | 186 |
|
174 | | - Args: |
175 | | - file_path (str): The path where the schema configuration will be saved. |
176 | | - """ |
177 | | - # create a copy of the data and convert tuples to lists for YAML compatibility |
178 | | - data = self.model_dump() |
179 | | - if data.get("node_types"): |
180 | | - data["node_types"] = list(data["node_types"]) |
181 | | - if data.get("relationship_types"): |
182 | | - data["relationship_types"] = list(data["relationship_types"]) |
183 | | - if data.get("patterns"): |
184 | | - data["patterns"] = [list(item) for item in data["patterns"]] |
185 | | - |
186 | | - with open(file_path, "w") as f: |
187 | | - yaml.dump(data, f, default_flow_style=False, sort_keys=False) |
| 187 | + def store_as_yaml( |
| 188 | + self, file_path: Union[str, Path], overwrite: bool = False |
| 189 | + ) -> None: |
| 190 | + warnings.warn( |
| 191 | + "Use .save(..., format=FileFormat.YAML) instead.", DeprecationWarning |
| 192 | + ) |
| 193 | + return self.save(file_path, overwrite=overwrite, format=FileFormat.YAML) |
188 | 194 |
|
189 | 195 | @classmethod |
190 | | - def from_file(cls, file_path: Union[str, Path]) -> Self: |
| 196 | + def from_file( |
| 197 | + cls, file_path: Union[str, Path], format: Optional[FileFormat] = None |
| 198 | + ) -> Self: |
191 | 199 | """ |
192 | 200 | Load a schema configuration from a file (either JSON or YAML). |
193 | 201 |
|
194 | | - The file format is automatically detected based on the file extension. |
| 202 | + The file format is automatically detected based on the file extension, |
| 203 | + unless the format parameter is set. |
195 | 204 |
|
196 | 205 | Args: |
197 | 206 | file_path (Union[str, Path]): The path to the schema configuration file. |
| 207 | + format (Optional[FileFormat]): The format of the schema configuration file (json or yaml). |
198 | 208 |
|
199 | 209 | Returns: |
200 | 210 | GraphSchema: The loaded schema configuration. |
201 | 211 | """ |
202 | 212 | file_path = Path(file_path) |
| 213 | + file_handler = FileHandler() |
| 214 | + try: |
| 215 | + data = file_handler.read(file_path, format=format) |
| 216 | + except ValueError: |
| 217 | + raise |
203 | 218 |
|
204 | | - if not file_path.exists(): |
205 | | - raise FileNotFoundError(f"Schema file not found: {file_path}") |
206 | | - |
207 | | - if file_path.suffix.lower() in [".json"]: |
208 | | - return cls.from_json(file_path) |
209 | | - elif file_path.suffix.lower() in [".yaml", ".yml"]: |
210 | | - return cls.from_yaml(file_path) |
211 | | - else: |
212 | | - raise ValueError( |
213 | | - f"Unsupported file format: {file_path.suffix}. Use .json, .yaml, or .yml" |
214 | | - ) |
215 | | - |
216 | | - @classmethod |
217 | | - def from_json(cls, file_path: Union[str, Path]) -> Self: |
218 | | - """ |
219 | | - Load a schema configuration from a JSON file. |
220 | | -
|
221 | | - Args: |
222 | | - file_path (Union[str, Path]): The path to the JSON schema configuration file. |
223 | | -
|
224 | | - Returns: |
225 | | - GraphSchema: The loaded schema configuration. |
226 | | - """ |
227 | | - with open(file_path, "r") as f: |
228 | | - try: |
229 | | - data = json.load(f) |
230 | | - return cls.model_validate(data) |
231 | | - except json.JSONDecodeError as e: |
232 | | - raise ValueError(f"Invalid JSON file: {e}") |
233 | | - except ValidationError as e: |
234 | | - raise SchemaValidationError(f"Schema validation failed: {e}") |
235 | | - |
236 | | - @classmethod |
237 | | - def from_yaml(cls, file_path: Union[str, Path]) -> Self: |
238 | | - """ |
239 | | - Load a schema configuration from a YAML file. |
240 | | -
|
241 | | - Args: |
242 | | - file_path (Union[str, Path]): The path to the YAML schema configuration file. |
243 | | -
|
244 | | - Returns: |
245 | | - GraphSchema: The loaded schema configuration. |
246 | | - """ |
247 | | - with open(file_path, "r") as f: |
248 | | - try: |
249 | | - data = yaml.safe_load(f) |
250 | | - return cls.model_validate(data) |
251 | | - except yaml.YAMLError as e: |
252 | | - raise ValueError(f"Invalid YAML file: {e}") |
253 | | - except ValidationError as e: |
254 | | - raise SchemaValidationError(f"Schema validation failed: {e}") |
| 219 | + try: |
| 220 | + return cls.model_validate(data) |
| 221 | + except ValidationError as e: |
| 222 | + raise SchemaValidationError(str(e)) from e |
255 | 223 |
|
256 | 224 |
|
257 | 225 | class SchemaBuilder(Component): |
|
0 commit comments