Skip to content

Commit b93c117

Browse files
committed
Use pydantic to parse cypher results into result classes
1 parent a1def51 commit b93c117

File tree

3 files changed

+24
-10
lines changed

3 files changed

+24
-10
lines changed

graphdatascience/procedure_surface/api/estimation_result.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,14 @@
1-
from dataclasses import dataclass
1+
from __future__ import annotations
2+
23
from typing import Any
34

5+
from pydantic import BaseModel, ConfigDict
6+
from pydantic.alias_generators import to_camel
7+
8+
9+
class EstimationResult(BaseModel):
10+
model_config = ConfigDict(alias_generator=to_camel)
411

5-
@dataclass(frozen=True, repr=True)
6-
class EstimationResult:
712
node_count: int
813
relationship_count: int
914
required_memory: str
@@ -16,3 +21,7 @@ class EstimationResult:
1621

1722
def __getitem__(self, item: str) -> Any:
1823
return getattr(self, item)
24+
25+
@staticmethod
26+
def from_cypher(cypher_result: dict[str, Any]) -> EstimationResult:
27+
return EstimationResult(**cypher_result)

graphdatascience/procedure_surface/api/wcc_endpoints.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
from __future__ import annotations
22

33
from abc import ABC, abstractmethod
4-
from dataclasses import dataclass
54
from typing import Any, List, Optional
65

76
from pandas import DataFrame
7+
from pydantic import BaseModel, ConfigDict
8+
from pydantic.alias_generators import to_camel
89

910
from ...graph.graph_object import Graph
1011
from .estimation_result import EstimationResult
@@ -272,8 +273,9 @@ def estimate(
272273
pass
273274

274275

275-
@dataclass(frozen=True, repr=True)
276-
class WccMutateResult:
276+
class WccMutateResult(BaseModel):
277+
model_config = ConfigDict(alias_generator=to_camel)
278+
277279
component_count: int
278280
component_distribution: dict[str, Any]
279281
pre_processing_millis: int
@@ -287,8 +289,9 @@ def __getitem__(self, item: str) -> Any:
287289
return getattr(self, item)
288290

289291

290-
@dataclass(frozen=True, repr=True)
291-
class WccStatsResult:
292+
class WccStatsResult(BaseModel):
293+
model_config = ConfigDict(alias_generator=to_camel)
294+
292295
component_count: int
293296
component_distribution: dict[str, Any]
294297
pre_processing_millis: int
@@ -300,8 +303,9 @@ def __getitem__(self, item: str) -> Any:
300303
return getattr(self, item)
301304

302305

303-
@dataclass(frozen=True, repr=True)
304-
class WccWriteResult:
306+
class WccWriteResult(BaseModel):
307+
model_config = ConfigDict(alias_generator=to_camel)
308+
305309
component_count: int
306310
component_distribution: dict[str, Any]
307311
pre_processing_millis: int

requirements/base/base.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,4 @@ tqdm >= 4.0, < 5.0
88
typing-extensions >= 4.0, < 5.0
99
requests
1010
tenacity >= 9.0
11+
pydantic >= 2.11

0 commit comments

Comments
 (0)