
Commit 97d7527

Merge pull request #679 from FlorentinD/session-run-cypher-fixes
Reduce run_cypher calls for GDS endpoints
2 parents de7fc4b + 162e0ba

File tree

8 files changed (+119 lines, -93 lines)
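The same migration repeats in every file below: instead of assembling a Cypher string and sending it through run_cypher, each endpoint now issues call_procedure with CallParameters and does any reshaping of the result client-side on the returned pandas row. A minimal before/after sketch, assuming a query_runner object exposing the run_cypher/call_procedure interface used in this repository; the two wrapper functions themselves are invented for illustration:

from typing import Any

from pandas import Series


def list_model_old(query_runner: Any, name: str) -> Any:
    # Before: hand-written Cypher routed through run_cypher.
    query = "CALL gds.beta.model.list($name)"
    return query_runner.run_cypher(query, {"name": name}, custom_error=False)


def list_model_new(query_runner: Any, params: Any) -> "Series[Any]":
    # After: a direct procedure call; .squeeze() collapses the one-row frame
    # into a Series so callers can index it by column name.
    result: Series[Any] = query_runner.call_procedure(
        "gds.beta.model.list", params=params, custom_error=False
    ).squeeze()
    return result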

graphdatascience/model/link_prediction_model.py

Lines changed: 1 addition & 1 deletion
@@ -41,5 +41,5 @@ def link_features(self) -> List[LinkFeature]:
             A list of LinkFeatures of the pipeline.
 
         """
-        steps: List[Dict[str, Any]] = self._list_info()["modelInfo"][0]["pipeline"]["featureSteps"]
+        steps: List[Dict[str, Any]] = self._list_info()["modelInfo"]["pipeline"]["featureSteps"]
         return [LinkFeature(s["name"], s["config"]) for s in steps]
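The dropped [0] in the line above follows from _list_info() now returning a squeezed single row (a pandas Series) instead of a one-row DataFrame, so a label lookup yields the cell value directly. A standalone pandas illustration (the nested values are placeholders):

import pandas as pd

# One-row frame, as the old run_cypher path returned it.
frame = pd.DataFrame([{"modelInfo": {"modelType": "graphSage"}, "loaded": True}])

# DataFrame access: a column is a Series, so the first cell needs an extra [0].
assert frame["modelInfo"][0]["modelType"] == "graphSage"

# Squeezed row, as call_procedure(...).squeeze() returns it: label access
# yields the cell directly, hence the removed [0] throughout this commit.
row = frame.squeeze()
assert row["modelInfo"]["modelType"] == "graphSage"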

graphdatascience/model/model.py

Lines changed: 62 additions & 60 deletions
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from abc import ABC, abstractmethod
 from typing import Any, Dict
 
@@ -21,34 +23,32 @@ def __init__(self, name: str, query_runner: QueryRunner, server_version: ServerV
     def _endpoint_prefix(self) -> str:
         pass
 
-    def _list_info(self) -> DataFrame:
+    def _list_info(self) -> Series[Any]:
+        params = CallParameters(name=self.name())
+
+        result: Series[Any]
         if self._server_version < ServerVersion(2, 5, 0):
-            query = "CALL gds.beta.model.list($name)"
+            result = self._query_runner.call_procedure(
+                "gds.beta.model.list", params=params, custom_error=False
+            ).squeeze()
         else:
-            query = """
-            CALL gds.model.list($name)
-            YIELD
-                modelName, modelType, modelInfo,
-                creationTime, trainConfig, graphSchema,
-                loaded, stored, published
-            RETURN
-                modelName, modelType,
-                modelInfo {.*, modelName: modelName, modelType: modelType} AS modelInfo,
-                creationTime, trainConfig, graphSchema,
-                loaded, stored, published, published AS shared
-            """
-
-        params = {"name": self.name()}
-
-        # FIXME use call procedure + do post processing on the client side
-        info = self._query_runner.run_cypher(query, params, custom_error=False)
-
-        if len(info) == 0:
+            result = self._query_runner.call_procedure("gds.model.list", params=params, custom_error=False).squeeze()
+
+            if not result.empty:
+                # 2.5 compat format
+                result["modelInfo"] = {
+                    **result["modelInfo"],
+                    "modelName": result["modelName"],
+                    "modelType": result["modelType"],
+                }
+                result["shared"] = result["published"]
+
+        if result.empty:
             raise ValueError(f"There is no '{self.name()}' in the model catalog")
 
-        return info
+        return result
 
-    def _estimate_predict(self, predict_mode: str, graph_name: str, config: Dict[str, Any]) -> "Series[Any]":
+    def _estimate_predict(self, predict_mode: str, graph_name: str, config: Dict[str, Any]) -> Series[Any]:
         endpoint = f"{self._endpoint_prefix()}{predict_mode}.estimate"
         config["modelName"] = self.name()
         params = CallParameters(graph_name=graph_name, config=config)
@@ -76,28 +76,28 @@ def type(self) -> str:
             The type of the model.
 
         """
-        return self._list_info()["modelInfo"][0]["modelType"]  # type: ignore
+        return self._list_info()["modelInfo"]["modelType"]  # type: ignore
 
-    def train_config(self) -> "Series[Any]":
+    def train_config(self) -> Series[Any]:
         """
         Get the train config of the model.
 
         Returns:
             The train config of the model.
 
         """
-        train_config: "Series[Any]" = Series(self._list_info()["trainConfig"][0])
+        train_config: Series[Any] = Series(self._list_info()["trainConfig"])
         return train_config
 
-    def graph_schema(self) -> "Series[Any]":
+    def graph_schema(self) -> Series[Any]:
         """
         Get the graph schema of the model.
 
         Returns:
             The graph schema of the model.
 
         """
-        graph_schema: "Series[Any]" = Series(self._list_info()["graphSchema"][0])
+        graph_schema: Series[Any] = Series(self._list_info()["graphSchema"])
         return graph_schema
 
     def loaded(self) -> bool:
@@ -108,7 +108,7 @@ def loaded(self) -> bool:
             True if the model is loaded in memory, False otherwise.
 
         """
-        return self._list_info()["loaded"].squeeze()  # type: ignore
+        return self._list_info()["loaded"]  # type: ignore
 
     def stored(self) -> bool:
         """
@@ -118,7 +118,7 @@ def stored(self) -> bool:
             True if the model is stored on disk, False otherwise.
 
         """
-        return self._list_info()["stored"].squeeze()  # type: ignore
+        return self._list_info()["stored"]  # type: ignore
 
     def creation_time(self) -> Any:  # neo4j.time.DateTime not exported
         """
@@ -128,7 +128,7 @@ def creation_time(self) -> Any:  # neo4j.time.DateTime not exported
             The creation time of the model.
 
         """
-        return self._list_info()["creationTime"].squeeze()
+        return self._list_info()["creationTime"]
 
     def shared(self) -> bool:
         """
@@ -138,7 +138,7 @@ def shared(self) -> bool:
            True if the model is shared, False otherwise.
 
         """
-        return self._list_info()["shared"].squeeze()  # type: ignore
+        return self._list_info()["shared"]  # type: ignore
 
     @compatible_with("published", min_inclusive=ServerVersion(2, 5, 0))
     def published(self) -> bool:
@@ -149,17 +149,17 @@ def published(self) -> bool:
            True if the model is published, False otherwise.
 
        """
-        return self._list_info()["published"].squeeze()  # type: ignore
+        return self._list_info()["published"]  # type: ignore
 
-    def model_info(self) -> "Series[Any]":
+    def model_info(self) -> Dict[str, Any]:
         """
         Get the model info of the model.
 
         Returns:
             The model info of the model.
 
         """
-        return Series(self._list_info()["modelInfo"].squeeze())
+        return Series(self._list_info()["modelInfo"])  # type: ignore
 
     def exists(self) -> bool:
         """
@@ -179,7 +179,7 @@ def exists(self) -> bool:
             endpoint=endpoint, params=params, yields=yields, custom_error=False
         ).squeeze()
 
-    def drop(self, failIfMissing: bool = False) -> "Series[Any]":
+    def drop(self, failIfMissing: bool = False) -> Series[Any]:
         """
         Drop the model.
 
@@ -190,36 +190,38 @@ def drop(self, failIfMissing: bool = False) -> "Series[Any]":
             The result of the drop operation.
 
         """
+        params = CallParameters(model_name=self._name, fail_if_missing=failIfMissing)
         if self._server_version < ServerVersion(2, 5, 0):
-            query = "CALL gds.beta.model.drop($model_name, $fail_if_missing)"
+            return self._query_runner.call_procedure(  # type: ignore
+                "gds.beta.model.drop", params=params, custom_error=False
+            ).squeeze()
         else:
-            query = """
-            CALL gds.model.drop($model_name, $fail_if_missing)
-            YIELD
-                modelName, modelType, modelInfo,
-                creationTime, trainConfig, graphSchema,
-                loaded, stored, published
-            RETURN
-                modelName, modelType,
-                modelInfo {.*, modelName: modelName, modelType: modelType} AS modelInfo,
-                creationTime, trainConfig, graphSchema,
-                loaded, stored, published, published AS shared
-            """
-
-        params = {"model_name": self._name, "fail_if_missing": failIfMissing}
-        # FIXME use call procedure + do post processing on the client side
-        return self._query_runner.run_cypher(query, params, custom_error=False).squeeze()  # type: ignore
-
-    def metrics(self) -> "Series[Any]":
+            result: Series[Any] = self._query_runner.call_procedure(
+                "gds.model.drop", params=params, custom_error=False
+            ).squeeze()
+
+            if result.empty:
+                return Series()
+
+            # modelInfo {.*, modelName: modelName, modelType: modelType} AS modelInfo
+            result["modelInfo"] = {
+                **result["modelInfo"],
+                "modelName": result["modelName"],
+                "modelType": result["modelType"],
+            }
+            result["shared"] = result["published"]
+            return result
+
+    def metrics(self) -> Series[Any]:
         """
         Get the metrics of the model.
 
         Returns:
             The metrics of the model.
 
         """
-        model_info = self._list_info()["modelInfo"][0]
-        metrics: "Series[Any]" = Series(model_info["metrics"])
+        model_info = self._list_info()["modelInfo"]
+        metrics: Series[Any] = Series(model_info["metrics"])
         return metrics
 
     @graph_type_check
@@ -242,7 +244,7 @@ def predict_stream(self, G: Graph, **config: Any) -> DataFrame:
         return self._query_runner.call_procedure(endpoint=endpoint, params=params, logging=True)
 
     @graph_type_check
-    def predict_stream_estimate(self, G: Graph, **config: Any) -> "Series[Any]":
+    def predict_stream_estimate(self, G: Graph, **config: Any) -> Series[Any]:
         """
         Estimate the prediction on the given graph using the model and stream the results as DataFrame
 
@@ -257,7 +259,7 @@ def predict_stream_estimate(self, G: Graph, **config: Any) -> "Series[Any]":
         return self._estimate_predict("stream", G.name(), config)
 
     @graph_type_check
-    def predict_mutate(self, G: Graph, **config: Any) -> "Series[Any]":
+    def predict_mutate(self, G: Graph, **config: Any) -> Series[Any]:
         """
         Predict on the given graph using the model and mutate the graph with the results.
 
@@ -278,7 +280,7 @@ def predict_mutate(self, G: Graph, **config: Any) -> "Series[Any]":
         ).squeeze()
 
     @graph_type_check
-    def predict_mutate_estimate(self, G: Graph, **config: Any) -> "Series[Any]":
+    def predict_mutate_estimate(self, G: Graph, **config: Any) -> Series[Any]:
         """
         Estimate the memory needed to predict on the given graph using the model.
 
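The map projection removed above (modelInfo {.*, modelName: modelName, modelType: modelType} AS modelInfo, plus published AS shared) is now applied in Python to the squeezed result row. A minimal standalone sketch of that reshaping, assuming the row is a pandas Series shaped like the gds.model.list output; the helper name and example values are invented for illustration:

from typing import Any

from pandas import Series


def as_25_compat_row(result: "Series[Any]") -> "Series[Any]":
    # Hypothetical helper mirroring the post-processing added in model.py:
    # fold the top-level modelName/modelType back into the modelInfo map and
    # alias published -> shared, as the removed RETURN clause used to do.
    if result.empty:
        return result
    result["modelInfo"] = {
        **result["modelInfo"],
        "modelName": result["modelName"],
        "modelType": result["modelType"],
    }
    result["shared"] = result["published"]
    return result


# Example row shaped like a squeezed gds.model.list record (values are placeholders).
row = Series({"modelName": "my-model", "modelType": "graphSage", "modelInfo": {}, "published": False})
assert as_25_compat_row(row)["modelInfo"]["modelType"] == "graphSage"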

graphdatascience/model/model_proc_runner.py

Lines changed: 23 additions & 15 deletions
@@ -1,4 +1,6 @@
-from typing import Any, Dict, List, Optional, Tuple
+from __future__ import annotations
+
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 from pandas import DataFrame, Series
 
@@ -49,23 +51,29 @@ def create(
 class ModelProcRunner(ModelResolver):
     @client_only_endpoint("gds.model")
     def get(self, model_name: str) -> Model:
+        params = CallParameters(model_name=model_name)
         if self._server_version < ServerVersion(2, 5, 0):
-            query = "CALL gds.beta.model.list($model_name) YIELD modelInfo RETURN modelInfo.modelType AS modelType"
+            endpoint = "gds.beta.model.list"
+            yields = ["modelInfo"]
+            result_25: Series[Any] = self._query_runner.call_procedure(
+                endpoint=endpoint, params=params, yields=yields, custom_error=False
+            ).squeeze()
+            model_type = str(result_25["modelInfo"]["modelType"]) if not result_25.empty else None
         else:
-            query = "CALL gds.model.list($model_name) YIELD modelType"
-
-        params = {"model_name": model_name}
-        # FIXME use call procedure + do post processing on the client side
-        result = self._query_runner.run_cypher(query, params, custom_error=False)
-
-        if len(result) == 0:
+            endpoint = "gds.model.list"
+            yields = ["modelType"]
+            result: Union[str, Series[Any]] = self._query_runner.call_procedure(
+                endpoint=endpoint, params=params, yields=yields, custom_error=False
+            ).squeeze()
+            model_type = result if isinstance(result, str) else None
+
+        if model_type is None:
             raise ValueError(f"No loaded model named '{model_name}' exists")
 
-        model_type = str(result["modelType"].squeeze())
         return self._resolve_model(model_type, model_name)
 
     @compatible_with("store", min_inclusive=ServerVersion(2, 5, 0))
-    def store(self, model: Model, failIfUnsupportedType: bool = True) -> "Series[Any]":
+    def store(self, model: Model, failIfUnsupportedType: bool = True) -> Series[Any]:
         self._namespace += ".store"
         params = CallParameters(model_name=model.name(), fail_flag=failIfUnsupportedType)
 
@@ -88,7 +96,7 @@ def publish(self, model: Model) -> Model:
         return self._resolve_model(model_type, model_name)
 
     @compatible_with("load", min_inclusive=ServerVersion(2, 5, 0))
-    def load(self, model_name: str) -> Tuple[Model, "Series[Any]"]:
+    def load(self, model_name: str) -> Tuple[Model, Series[Any]]:
         self._namespace += ".load"
 
         params = CallParameters(model_name=model_name)
@@ -101,7 +109,7 @@ def load(self, model_name: str) -> Tuple[Model, "Series[Any]"]:
         return proc_runner.get(result["modelName"]), result
 
     @compatible_with("delete", min_inclusive=ServerVersion(2, 5, 0))
-    def delete(self, model: Model) -> "Series[Any]":
+    def delete(self, model: Model) -> Series[Any]:
         self._namespace += ".delete"
         params = CallParameters(model_name=model.name())
         return self._query_runner.call_procedure(endpoint=self._namespace, params=params).squeeze()  # type: ignore
@@ -117,15 +125,15 @@ def list(self, model: Optional[Model] = None) -> DataFrame:
         return self._query_runner.call_procedure(endpoint=self._namespace, params=params)
 
     @compatible_with("exists", min_inclusive=ServerVersion(2, 5, 0))
-    def exists(self, model_name: str) -> "Series[Any]":
+    def exists(self, model_name: str) -> Series[Any]:
         self._namespace += ".exists"
 
         return self._query_runner.call_procedure(  # type: ignore
             endpoint=self._namespace, params=CallParameters(model_name=model_name)
         ).squeeze()
 
     @compatible_with("drop", min_inclusive=ServerVersion(2, 5, 0))
-    def drop(self, model: Model) -> "Series[Any]":
+    def drop(self, model: Model) -> Series[Any]:
         self._namespace += ".drop"
 
         return self._query_runner.call_procedure(  # type: ignore
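The Union[str, Series[Any]] type and the isinstance check in get() follow from how DataFrame.squeeze() behaves: a single match (one row, one yielded column) collapses all the way down to a bare scalar, while no match leaves an empty Series. A pure-pandas illustration, independent of the GDS client:

import pandas as pd

# One matching model: a 1-row, 1-column frame squeezes down to the cell value,
# so the caller receives a plain string.
hit = pd.DataFrame({"modelType": ["graphSage"]}).squeeze()
assert isinstance(hit, str) and hit == "graphSage"

# No matching model: only the length-1 column axis is squeezed, leaving an
# empty Series, which is why get() falls back to model_type = None.
miss = pd.DataFrame({"modelType": []}).squeeze()
assert isinstance(miss, pd.Series) and miss.empty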

graphdatascience/model/node_classification_model.py

Lines changed: 2 additions & 2 deletions
@@ -62,7 +62,7 @@ def classes(self) -> List[int]:
             The classes of the model.
 
         """
-        return self._list_info()["modelInfo"][0]["classes"]  # type: ignore
+        return self._list_info()["modelInfo"]["classes"]  # type: ignore
 
     def feature_properties(self) -> List[str]:
         """
@@ -72,5 +72,5 @@ def feature_properties(self) -> List[str]:
             The feature properties of the model.
 
         """
-        features: List[Dict[str, Any]] = self._list_info()["modelInfo"][0]["pipeline"]["featureProperties"]
+        features: List[Dict[str, Any]] = self._list_info()["modelInfo"]["pipeline"]["featureProperties"]
         return [f["feature"] for f in features]

graphdatascience/model/node_regression_model.py

Lines changed: 1 addition & 1 deletion
@@ -21,5 +21,5 @@ def feature_properties(self) -> List[str]:
             The feature properties of the model.
 
         """
-        features: List[Dict[str, Any]] = self._list_info()["modelInfo"][0]["pipeline"]["featureProperties"]
+        features: List[Dict[str, Any]] = self._list_info()["modelInfo"]["pipeline"]["featureProperties"]
         return [f["feature"] for f in features]

graphdatascience/model/pipeline_model.py

Lines changed: 3 additions & 3 deletions
@@ -75,7 +75,7 @@ def best_parameters(self) -> "Series[Any]":
             The best parameters for the pipeline model.
 
         """
-        best_params: Dict[str, Any] = self._list_info()["modelInfo"][0]["bestParameters"]
+        best_params: Dict[str, Any] = self._list_info()["modelInfo"]["bestParameters"]
         return Series(best_params)
 
     def node_property_steps(self) -> List[NodePropertyStep]:
@@ -86,7 +86,7 @@ def node_property_steps(self) -> List[NodePropertyStep]:
             The node property steps for the pipeline model.
 
         """
-        steps: List[Dict[str, Any]] = self._list_info()["modelInfo"][0]["pipeline"]["nodePropertySteps"]
+        steps: List[Dict[str, Any]] = self._list_info()["modelInfo"]["pipeline"]["nodePropertySteps"]
         return [NodePropertyStep(s["name"], s["config"]) for s in steps]
 
     def metrics(self) -> "Series[Any]":
@@ -97,6 +97,6 @@ def metrics(self) -> "Series[Any]":
             The metrics for the pipeline model.
 
         """
-        model_metrics: Dict[str, Any] = self._list_info()["modelInfo"][0]["metrics"]
+        model_metrics: Dict[str, Any] = self._list_info()["modelInfo"]["metrics"]
         metric_scores: Dict[str, MetricScores] = {k: MetricScores.create(v) for k, v in (model_metrics.items())}
         return Series(metric_scores)
