Skip to content

Commit 162e0ba

Browse files
committed
Rewrite model._model_info to use call_endpoint
1 parent 5eddbe6 commit 162e0ba

File tree

6 files changed

+53
-44
lines changed

6 files changed

+53
-44
lines changed

graphdatascience/model/link_prediction_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,5 +41,5 @@ def link_features(self) -> List[LinkFeature]:
4141
A list of LinkFeatures of the pipeline.
4242
4343
"""
44-
steps: List[Dict[str, Any]] = self._list_info()["modelInfo"][0]["pipeline"]["featureSteps"]
44+
steps: List[Dict[str, Any]] = self._list_info()["modelInfo"]["pipeline"]["featureSteps"]
4545
return [LinkFeature(s["name"], s["config"]) for s in steps]

graphdatascience/model/model.py

Lines changed: 34 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
from abc import ABC, abstractmethod
4-
from typing import Any, Dict, Optional
4+
from typing import Any, Dict
55

66
from pandas import DataFrame, Series
77

@@ -23,33 +23,30 @@ def __init__(self, name: str, query_runner: QueryRunner, server_version: ServerV
2323
def _endpoint_prefix(self) -> str:
2424
pass
2525

26-
def _list_info(self) -> DataFrame:
26+
def _list_info(self) -> Series[Any]:
27+
params = CallParameters(name=self.name())
28+
29+
result: Series[Any]
2730
if self._server_version < ServerVersion(2, 5, 0):
28-
query = "CALL gds.beta.model.list($name)"
31+
result = self._query_runner.call_procedure(
32+
"gds.beta.model.list", params=params, custom_error=False
33+
).squeeze()
2934
else:
30-
query = """
31-
CALL gds.model.list($name)
32-
YIELD
33-
modelName, modelType, modelInfo,
34-
creationTime, trainConfig, graphSchema,
35-
loaded, stored, published
36-
RETURN
37-
modelName, modelType,
38-
modelInfo {.*, modelName: modelName, modelType: modelType} AS modelInfo,
39-
creationTime, trainConfig, graphSchema,
40-
loaded, stored, published, published AS shared
41-
"""
42-
43-
params = {"name": self.name()}
44-
45-
# FIXME use call procedure + do post processing on the client side
46-
# TODO really fixme
47-
info = self._query_runner.run_cypher(query, params, custom_error=False)
48-
49-
if len(info) == 0:
35+
result = self._query_runner.call_procedure("gds.model.list", params=params, custom_error=False).squeeze()
36+
37+
if not result.empty:
38+
# 2.5 compat format
39+
result["modelInfo"] = {
40+
**result["modelInfo"],
41+
"modelName": result["modelName"],
42+
"modelType": result["modelType"],
43+
}
44+
result["shared"] = result["published"]
45+
46+
if result.empty:
5047
raise ValueError(f"There is no '{self.name()}' in the model catalog")
5148

52-
return info
49+
return result
5350

5451
def _estimate_predict(self, predict_mode: str, graph_name: str, config: Dict[str, Any]) -> Series[Any]:
5552
endpoint = f"{self._endpoint_prefix()}{predict_mode}.estimate"
@@ -79,7 +76,7 @@ def type(self) -> str:
7976
The type of the model.
8077
8178
"""
82-
return self._list_info()["modelInfo"][0]["modelType"] # type: ignore
79+
return self._list_info()["modelInfo"]["modelType"] # type: ignore
8380

8481
def train_config(self) -> Series[Any]:
8582
"""
@@ -89,7 +86,7 @@ def train_config(self) -> Series[Any]:
8986
The train config of the model.
9087
9188
"""
92-
train_config: Series[Any] = Series(self._list_info()["trainConfig"][0])
89+
train_config: Series[Any] = Series(self._list_info()["trainConfig"])
9390
return train_config
9491

9592
def graph_schema(self) -> Series[Any]:
@@ -100,7 +97,7 @@ def graph_schema(self) -> Series[Any]:
10097
The graph schema of the model.
10198
10299
"""
103-
graph_schema: Series[Any] = Series(self._list_info()["graphSchema"][0])
100+
graph_schema: Series[Any] = Series(self._list_info()["graphSchema"])
104101
return graph_schema
105102

106103
def loaded(self) -> bool:
@@ -111,7 +108,7 @@ def loaded(self) -> bool:
111108
True if the model is loaded in memory, False otherwise.
112109
113110
"""
114-
return self._list_info()["loaded"].squeeze() # type: ignore
111+
return self._list_info()["loaded"] # type: ignore
115112

116113
def stored(self) -> bool:
117114
"""
@@ -121,7 +118,7 @@ def stored(self) -> bool:
121118
True if the model is stored on disk, False otherwise.
122119
123120
"""
124-
return self._list_info()["stored"].squeeze() # type: ignore
121+
return self._list_info()["stored"] # type: ignore
125122

126123
def creation_time(self) -> Any: # neo4j.time.DateTime not exported
127124
"""
@@ -131,7 +128,7 @@ def creation_time(self) -> Any: # neo4j.time.DateTime not exported
131128
The creation time of the model.
132129
133130
"""
134-
return self._list_info()["creationTime"].squeeze()
131+
return self._list_info()["creationTime"]
135132

136133
def shared(self) -> bool:
137134
"""
@@ -141,7 +138,7 @@ def shared(self) -> bool:
141138
True if the model is shared, False otherwise.
142139
143140
"""
144-
return self._list_info()["shared"].squeeze() # type: ignore
141+
return self._list_info()["shared"] # type: ignore
145142

146143
@compatible_with("published", min_inclusive=ServerVersion(2, 5, 0))
147144
def published(self) -> bool:
@@ -152,17 +149,17 @@ def published(self) -> bool:
152149
True if the model is published, False otherwise.
153150
154151
"""
155-
return self._list_info()["published"].squeeze() # type: ignore
152+
return self._list_info()["published"] # type: ignore
156153

157-
def model_info(self) -> Series[Any]:
154+
def model_info(self) -> Dict[str, Any]:
158155
"""
159156
Get the model info of the model.
160157
161158
Returns:
162159
The model info of the model.
163160
164161
"""
165-
return Series(self._list_info()["modelInfo"].squeeze())
162+
return Series(self._list_info()["modelInfo"]) # type: ignore
166163

167164
def exists(self) -> bool:
168165
"""
@@ -199,11 +196,11 @@ def drop(self, failIfMissing: bool = False) -> Series[Any]:
199196
"gds.beta.model.drop", params=params, custom_error=False
200197
).squeeze()
201198
else:
202-
result: Optional[Series[Any]] = self._query_runner.call_procedure(
199+
result: Series[Any] = self._query_runner.call_procedure(
203200
"gds.model.drop", params=params, custom_error=False
204201
).squeeze()
205202

206-
if result is None:
203+
if result.empty:
207204
return Series()
208205

209206
# modelInfo {.*, modelName: modelName, modelType: modelType} AS modelInfo
@@ -223,7 +220,7 @@ def metrics(self) -> Series[Any]:
223220
The metrics of the model.
224221
225222
"""
226-
model_info = self._list_info()["modelInfo"][0]
223+
model_info = self._list_info()["modelInfo"]
227224
metrics: Series[Any] = Series(model_info["metrics"])
228225
return metrics
229226

graphdatascience/model/node_classification_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def classes(self) -> List[int]:
6262
The classes of the model.
6363
6464
"""
65-
return self._list_info()["modelInfo"][0]["classes"] # type: ignore
65+
return self._list_info()["modelInfo"]["classes"] # type: ignore
6666

6767
def feature_properties(self) -> List[str]:
6868
"""
@@ -72,5 +72,5 @@ def feature_properties(self) -> List[str]:
7272
The feature properties of the model.
7373
7474
"""
75-
features: List[Dict[str, Any]] = self._list_info()["modelInfo"][0]["pipeline"]["featureProperties"]
75+
features: List[Dict[str, Any]] = self._list_info()["modelInfo"]["pipeline"]["featureProperties"]
7676
return [f["feature"] for f in features]

graphdatascience/model/node_regression_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,5 +21,5 @@ def feature_properties(self) -> List[str]:
2121
The feature properties of the model.
2222
2323
"""
24-
features: List[Dict[str, Any]] = self._list_info()["modelInfo"][0]["pipeline"]["featureProperties"]
24+
features: List[Dict[str, Any]] = self._list_info()["modelInfo"]["pipeline"]["featureProperties"]
2525
return [f["feature"] for f in features]

graphdatascience/model/pipeline_model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def best_parameters(self) -> "Series[Any]":
7575
The best parameters for the pipeline model.
7676
7777
"""
78-
best_params: Dict[str, Any] = self._list_info()["modelInfo"][0]["bestParameters"]
78+
best_params: Dict[str, Any] = self._list_info()["modelInfo"]["bestParameters"]
7979
return Series(best_params)
8080

8181
def node_property_steps(self) -> List[NodePropertyStep]:
@@ -86,7 +86,7 @@ def node_property_steps(self) -> List[NodePropertyStep]:
8686
The node property steps for the pipeline model.
8787
8888
"""
89-
steps: List[Dict[str, Any]] = self._list_info()["modelInfo"][0]["pipeline"]["nodePropertySteps"]
89+
steps: List[Dict[str, Any]] = self._list_info()["modelInfo"]["pipeline"]["nodePropertySteps"]
9090
return [NodePropertyStep(s["name"], s["config"]) for s in steps]
9191

9292
def metrics(self) -> "Series[Any]":
@@ -97,6 +97,6 @@ def metrics(self) -> "Series[Any]":
9797
The metrics for the pipeline model.
9898
9999
"""
100-
model_metrics: Dict[str, Any] = self._list_info()["modelInfo"][0]["metrics"]
100+
model_metrics: Dict[str, Any] = self._list_info()["modelInfo"]["metrics"]
101101
metric_scores: Dict[str, MetricScores] = {k: MetricScores.create(v) for k, v in (model_metrics.items())}
102102
return Series(metric_scores)

graphdatascience/tests/integration/test_model_ops.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,18 @@ def test_model_get_no_model(gds: GraphDataScience) -> None:
305305
gds.model.get("no_model")
306306

307307

308+
def test_missing_model_info(gds: GraphDataScience) -> None:
309+
model = GraphSageModel("ghost-model", gds._query_runner, gds.server_version())
310+
with pytest.raises(ValueError, match="There is no 'ghost-model' in the model catalog"):
311+
model.model_info()
312+
313+
314+
def test_missing_model_drop(gds: GraphDataScience) -> None:
315+
model = GraphSageModel("ghost-model", gds._query_runner, gds.server_version())
316+
317+
assert model.drop(failIfMissing=False).empty
318+
319+
308320
@pytest.mark.model_store_location
309321
@pytest.mark.filterwarnings("ignore: The query used a deprecated procedure.")
310322
def test_deprecated_model_Functions_still_work(gds: GraphDataScience, gs_model: GraphSageModel) -> None:

0 commit comments

Comments
 (0)