Skip to content

Commit 1c8c99b

Browse files
committed
Reduce run_cypher calls for GDS endpoints
1 parent de7fc4b commit 1c8c99b

File tree

4 files changed

+72
-55
lines changed

graphdatascience/model/model.py

Lines changed: 34 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
from __future__ import annotations
2+
13
from abc import ABC, abstractmethod
2-
from typing import Any, Dict
4+
from typing import Any, Dict, Optional
35

46
from pandas import DataFrame, Series
57

@@ -41,14 +43,15 @@ def _list_info(self) -> DataFrame:
4143
params = {"name": self.name()}
4244

4345
# FIXME use call procedure + do post processing on the client side
46+
# TODO really fixme
4447
info = self._query_runner.run_cypher(query, params, custom_error=False)
4548

4649
if len(info) == 0:
4750
raise ValueError(f"There is no '{self.name()}' in the model catalog")
4851

4952
return info
5053

51-
def _estimate_predict(self, predict_mode: str, graph_name: str, config: Dict[str, Any]) -> "Series[Any]":
54+
def _estimate_predict(self, predict_mode: str, graph_name: str, config: Dict[str, Any]) -> Series[Any]:
5255
endpoint = f"{self._endpoint_prefix()}{predict_mode}.estimate"
5356
config["modelName"] = self.name()
5457
params = CallParameters(graph_name=graph_name, config=config)
@@ -78,26 +81,26 @@ def type(self) -> str:
7881
"""
7982
return self._list_info()["modelInfo"][0]["modelType"] # type: ignore
8083

81-
def train_config(self) -> "Series[Any]":
84+
def train_config(self) -> Series[Any]:
8285
"""
8386
Get the train config of the model.
8487
8588
Returns:
8689
The train config of the model.
8790
8891
"""
89-
train_config: "Series[Any]" = Series(self._list_info()["trainConfig"][0])
92+
train_config: Series[Any] = Series(self._list_info()["trainConfig"][0])
9093
return train_config
9194

92-
def graph_schema(self) -> "Series[Any]":
95+
def graph_schema(self) -> Series[Any]:
9396
"""
9497
Get the graph schema of the model.
9598
9699
Returns:
97100
The graph schema of the model.
98101
99102
"""
100-
graph_schema: "Series[Any]" = Series(self._list_info()["graphSchema"][0])
103+
graph_schema: Series[Any] = Series(self._list_info()["graphSchema"][0])
101104
return graph_schema
102105

103106
def loaded(self) -> bool:
@@ -151,7 +154,7 @@ def published(self) -> bool:
151154
"""
152155
return self._list_info()["published"].squeeze() # type: ignore
153156

154-
def model_info(self) -> "Series[Any]":
157+
def model_info(self) -> Series[Any]:
155158
"""
156159
Get the model info of the model.
157160
@@ -179,7 +182,7 @@ def exists(self) -> bool:
179182
endpoint=endpoint, params=params, yields=yields, custom_error=False
180183
).squeeze()
181184

182-
def drop(self, failIfMissing: bool = False) -> "Series[Any]":
185+
def drop(self, failIfMissing: bool = False) -> Series[Any]:
183186
"""
184187
Drop the model.
185188
@@ -190,27 +193,29 @@ def drop(self, failIfMissing: bool = False) -> "Series[Any]":
190193
The result of the drop operation.
191194
192195
"""
196+
params = CallParameters(model_name=self._name, fail_if_missing=failIfMissing)
193197
if self._server_version < ServerVersion(2, 5, 0):
194-
query = "CALL gds.beta.model.drop($model_name, $fail_if_missing)"
198+
return self._query_runner.call_procedure( # type: ignore
199+
"gds.beta.model.drop", params=params, custom_error=False
200+
).squeeze()
195201
else:
196-
query = """
197-
CALL gds.model.drop($model_name, $fail_if_missing)
198-
YIELD
199-
modelName, modelType, modelInfo,
200-
creationTime, trainConfig, graphSchema,
201-
loaded, stored, published
202-
RETURN
203-
modelName, modelType,
204-
modelInfo {.*, modelName: modelName, modelType: modelType} AS modelInfo,
205-
creationTime, trainConfig, graphSchema,
206-
loaded, stored, published, published AS shared
207-
"""
202+
result: Optional[Series[Any]] = self._query_runner.call_procedure(
203+
"gds.model.drop", params=params, custom_error=False
204+
).squeeze()
208205

209-
params = {"model_name": self._name, "fail_if_missing": failIfMissing}
210-
# FIXME use call procedure + do post processing on the client side
211-
return self._query_runner.run_cypher(query, params, custom_error=False).squeeze() # type: ignore
206+
if result is None:
207+
return Series()
208+
209+
# modelInfo {.*, modelName: modelName, modelType: modelType} AS modelInfo
210+
result["modelInfo"] = {
211+
**result["modelInfo"],
212+
"modelName": result["modelName"],
213+
"modelType": result["modelType"],
214+
}
215+
result["shared"] = result["published"]
216+
return result
212217

213-
def metrics(self) -> "Series[Any]":
218+
def metrics(self) -> Series[Any]:
214219
"""
215220
Get the metrics of the model.
216221
@@ -219,7 +224,7 @@ def metrics(self) -> "Series[Any]":
219224
220225
"""
221226
model_info = self._list_info()["modelInfo"][0]
222-
metrics: "Series[Any]" = Series(model_info["metrics"])
227+
metrics: Series[Any] = Series(model_info["metrics"])
223228
return metrics
224229

225230
@graph_type_check
@@ -242,7 +247,7 @@ def predict_stream(self, G: Graph, **config: Any) -> DataFrame:
242247
return self._query_runner.call_procedure(endpoint=endpoint, params=params, logging=True)
243248

244249
@graph_type_check
245-
def predict_stream_estimate(self, G: Graph, **config: Any) -> "Series[Any]":
250+
def predict_stream_estimate(self, G: Graph, **config: Any) -> Series[Any]:
246251
"""
247252
Estimate the prediction on the given graph using the model and stream the results as DataFrame
248253
@@ -257,7 +262,7 @@ def predict_stream_estimate(self, G: Graph, **config: Any) -> "Series[Any]":
257262
return self._estimate_predict("stream", G.name(), config)
258263

259264
@graph_type_check
260-
def predict_mutate(self, G: Graph, **config: Any) -> "Series[Any]":
265+
def predict_mutate(self, G: Graph, **config: Any) -> Series[Any]:
261266
"""
262267
Predict on the given graph using the model and mutate the graph with the results.
263268
@@ -278,7 +283,7 @@ def predict_mutate(self, G: Graph, **config: Any) -> "Series[Any]":
278283
).squeeze()
279284

280285
@graph_type_check
281-
def predict_mutate_estimate(self, G: Graph, **config: Any) -> "Series[Any]":
286+
def predict_mutate_estimate(self, G: Graph, **config: Any) -> Series[Any]:
282287
"""
283288
Estimate the memory needed to predict on the given graph using the model.
284289

graphdatascience/model/model_proc_runner.py

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
from typing import Any, Dict, List, Optional, Tuple
1+
from __future__ import annotations
2+
3+
from typing import Any, Dict, List, Optional, Tuple, Union
24

35
from pandas import DataFrame, Series
46

@@ -49,23 +51,29 @@ def create(
4951
class ModelProcRunner(ModelResolver):
5052
@client_only_endpoint("gds.model")
5153
def get(self, model_name: str) -> Model:
54+
params = CallParameters(model_name=model_name)
5255
if self._server_version < ServerVersion(2, 5, 0):
53-
query = "CALL gds.beta.model.list($model_name) YIELD modelInfo RETURN modelInfo.modelType AS modelType"
56+
endpoint = "gds.beta.model.list"
57+
yields = ["modelInfo"]
58+
result: Series[Any] = self._query_runner.call_procedure(
59+
endpoint=endpoint, params=params, yields=yields, custom_error=False
60+
).squeeze()
61+
model_type = str(result["modelInfo"]["modelType"]) if not result.empty else None
5462
else:
55-
query = "CALL gds.model.list($model_name) YIELD modelType"
56-
57-
params = {"model_name": model_name}
58-
# FIXME use call procedure + do post processing on the client side
59-
result = self._query_runner.run_cypher(query, params, custom_error=False)
60-
61-
if len(result) == 0:
63+
endpoint = "gds.model.list"
64+
yields = ["modelType"]
65+
result: Union[str, Series[Any]] = self._query_runner.call_procedure(
66+
endpoint=endpoint, params=params, yields=yields, custom_error=False
67+
).squeeze()
68+
model_type = result if isinstance(result, str) else None
69+
70+
if model_type is None:
6271
raise ValueError(f"No loaded model named '{model_name}' exists")
6372

64-
model_type = str(result["modelType"].squeeze())
6573
return self._resolve_model(model_type, model_name)
6674

6775
@compatible_with("store", min_inclusive=ServerVersion(2, 5, 0))
68-
def store(self, model: Model, failIfUnsupportedType: bool = True) -> "Series[Any]":
76+
def store(self, model: Model, failIfUnsupportedType: bool = True) -> Series[Any]:
6977
self._namespace += ".store"
7078
params = CallParameters(model_name=model.name(), fail_flag=failIfUnsupportedType)
7179

@@ -88,7 +96,7 @@ def publish(self, model: Model) -> Model:
8896
return self._resolve_model(model_type, model_name)
8997

9098
@compatible_with("load", min_inclusive=ServerVersion(2, 5, 0))
91-
def load(self, model_name: str) -> Tuple[Model, "Series[Any]"]:
99+
def load(self, model_name: str) -> Tuple[Model, Series[Any]]:
92100
self._namespace += ".load"
93101

94102
params = CallParameters(model_name=model_name)
@@ -101,7 +109,7 @@ def load(self, model_name: str) -> Tuple[Model, "Series[Any]"]:
101109
return proc_runner.get(result["modelName"]), result
102110

103111
@compatible_with("delete", min_inclusive=ServerVersion(2, 5, 0))
104-
def delete(self, model: Model) -> "Series[Any]":
112+
def delete(self, model: Model) -> Series[Any]:
105113
self._namespace += ".delete"
106114
params = CallParameters(model_name=model.name())
107115
return self._query_runner.call_procedure(endpoint=self._namespace, params=params).squeeze() # type: ignore
@@ -117,15 +125,15 @@ def list(self, model: Optional[Model] = None) -> DataFrame:
117125
return self._query_runner.call_procedure(endpoint=self._namespace, params=params)
118126

119127
@compatible_with("exists", min_inclusive=ServerVersion(2, 5, 0))
120-
def exists(self, model_name: str) -> "Series[Any]":
128+
def exists(self, model_name: str) -> Series[Any]:
121129
self._namespace += ".exists"
122130

123131
return self._query_runner.call_procedure( # type: ignore
124132
endpoint=self._namespace, params=CallParameters(model_name=model_name)
125133
).squeeze()
126134

127135
@compatible_with("drop", min_inclusive=ServerVersion(2, 5, 0))
128-
def drop(self, model: Model) -> "Series[Any]":
136+
def drop(self, model: Model) -> Series[Any]:
129137
self._namespace += ".drop"
130138

131139
return self._query_runner.call_procedure( # type: ignore

graphdatascience/tests/integration/test_model_object.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def gs_model(gds: GraphDataScience, G: Graph, runner: Neo4jQueryRunner) -> Gener
4444
yield model
4545

4646
namespace = "beta." if gds.server_version() < ServerVersion(2, 5, 0) else ""
47-
query = f"CALL gds.{namespace}model.drop($name)"
47+
query = f"CALL gds.{namespace}model.drop($name, false)"
4848
params = {"name": model.name()}
4949
runner.run_cypher(query, params)
5050

@@ -53,25 +53,24 @@ def test_model_exists(gs_model: GraphSageModel) -> None:
5353
assert gs_model.exists()
5454

5555

56-
def test_model_drop(gds: GraphDataScience, G: Graph) -> None:
57-
model, _ = gds.beta.graphSage.train(G, modelName="gs-model", featureProperties=["age"])
56+
def test_model_drop(gds: GraphDataScience, G: Graph, gs_model: GraphSageModel) -> None:
57+
model_type = gs_model.type()
58+
model_published = gs_model.shared()
5859

59-
model_type = model.type()
60-
model_published = model.shared()
61-
drop_result = model.drop()
60+
drop_result = gs_model.drop()
6261
if gds.server_version() >= ServerVersion(2, 5, 0):
63-
assert drop_result["modelName"] == model.name()
62+
assert drop_result["modelName"] == gs_model.name()
6463
assert drop_result["modelType"] == model_type
6564
assert drop_result["published"] == model_published
66-
assert drop_result["modelInfo"]["modelName"] == model.name()
65+
assert drop_result["modelInfo"]["modelName"] == gs_model.name()
6766

68-
assert not model.exists()
67+
assert not gs_model.exists()
6968

7069
# Should not raise error.
71-
model.drop(failIfMissing=False)
70+
gs_model.drop(failIfMissing=False)
7271

7372
with pytest.raises(Exception):
74-
model.drop(failIfMissing=True)
73+
gs_model.drop(failIfMissing=True)
7574

7675

7776
def test_model_name(gs_model: GraphSageModel) -> None:

graphdatascience/tests/integration/test_model_ops.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,11 @@ def test_model_get_graphsage(gds: GraphDataScience, gs_model: GraphSageModel) ->
300300
model.drop()
301301

302302

303+
def test_model_get_no_model(gds: GraphDataScience) -> None:
304+
with pytest.raises(ValueError, match="No loaded model named 'no_model' exists"):
305+
gds.model.get("no_model")
306+
307+
303308
@pytest.mark.model_store_location
304309
@pytest.mark.filterwarnings("ignore: The query used a deprecated procedure.")
305310
def test_deprecated_model_Functions_still_work(gds: GraphDataScience, gs_model: GraphSageModel) -> None:

Comments (0)