1+ from __future__ import annotations
2+
13from abc import ABC , abstractmethod
24from typing import Any , Dict
35
@@ -21,34 +23,32 @@ def __init__(self, name: str, query_runner: QueryRunner, server_version: ServerV
2123 def _endpoint_prefix (self ) -> str :
2224 pass
2325
24- def _list_info (self ) -> DataFrame :
26+ def _list_info (self ) -> Series [Any ]:
27+ params = CallParameters (name = self .name ())
28+
29+ result : Series [Any ]
2530 if self ._server_version < ServerVersion (2 , 5 , 0 ):
26- query = "CALL gds.beta.model.list($name)"
31+ result = self ._query_runner .call_procedure (
32+ "gds.beta.model.list" , params = params , custom_error = False
33+ ).squeeze ()
2734 else :
28- query = """
29- CALL gds.model.list($name)
30- YIELD
31- modelName, modelType, modelInfo,
32- creationTime, trainConfig, graphSchema,
33- loaded, stored, published
34- RETURN
35- modelName, modelType,
36- modelInfo {.*, modelName: modelName, modelType: modelType} AS modelInfo,
37- creationTime, trainConfig, graphSchema,
38- loaded, stored, published, published AS shared
39- """
40-
41- params = {"name" : self .name ()}
42-
43- # FIXME use call procedure + do post processing on the client side
44- info = self ._query_runner .run_cypher (query , params , custom_error = False )
45-
46- if len (info ) == 0 :
35+ result = self ._query_runner .call_procedure ("gds.model.list" , params = params , custom_error = False ).squeeze ()
36+
37+ if not result .empty :
38+ # 2.5 compat format
39+ result ["modelInfo" ] = {
40+ ** result ["modelInfo" ],
41+ "modelName" : result ["modelName" ],
42+ "modelType" : result ["modelType" ],
43+ }
44+ result ["shared" ] = result ["published" ]
45+
46+ if result .empty :
4747 raise ValueError (f"There is no '{ self .name ()} ' in the model catalog" )
4848
49- return info
49+ return result
5050
51- def _estimate_predict (self , predict_mode : str , graph_name : str , config : Dict [str , Any ]) -> " Series[Any]" :
51+ def _estimate_predict (self , predict_mode : str , graph_name : str , config : Dict [str , Any ]) -> Series [Any ]:
5252 endpoint = f"{ self ._endpoint_prefix ()} { predict_mode } .estimate"
5353 config ["modelName" ] = self .name ()
5454 params = CallParameters (graph_name = graph_name , config = config )
@@ -76,28 +76,28 @@ def type(self) -> str:
7676 The type of the model.
7777
7878 """
79- return self ._list_info ()["modelInfo" ][0 ][ "modelType" ] # type: ignore
79+ return self ._list_info ()["modelInfo" ]["modelType" ] # type: ignore
8080
81- def train_config (self ) -> " Series[Any]" :
81+ def train_config (self ) -> Series [Any ]:
8282 """
8383 Get the train config of the model.
8484
8585 Returns:
8686 The train config of the model.
8787
8888 """
89- train_config : " Series[Any]" = Series (self ._list_info ()["trainConfig" ][ 0 ])
89+ train_config : Series [Any ] = Series (self ._list_info ()["trainConfig" ])
9090 return train_config
9191
92- def graph_schema (self ) -> " Series[Any]" :
92+ def graph_schema (self ) -> Series [Any ]:
9393 """
9494 Get the graph schema of the model.
9595
9696 Returns:
9797 The graph schema of the model.
9898
9999 """
100- graph_schema : " Series[Any]" = Series (self ._list_info ()["graphSchema" ][ 0 ])
100+ graph_schema : Series [Any ] = Series (self ._list_info ()["graphSchema" ])
101101 return graph_schema
102102
103103 def loaded (self ) -> bool :
@@ -108,7 +108,7 @@ def loaded(self) -> bool:
108108 True if the model is loaded in memory, False otherwise.
109109
110110 """
111- return self ._list_info ()["loaded" ]. squeeze () # type: ignore
111+ return self ._list_info ()["loaded" ] # type: ignore
112112
113113 def stored (self ) -> bool :
114114 """
@@ -118,7 +118,7 @@ def stored(self) -> bool:
118118 True if the model is stored on disk, False otherwise.
119119
120120 """
121- return self ._list_info ()["stored" ]. squeeze () # type: ignore
121+ return self ._list_info ()["stored" ] # type: ignore
122122
123123 def creation_time (self ) -> Any : # neo4j.time.DateTime not exported
124124 """
@@ -128,7 +128,7 @@ def creation_time(self) -> Any: # neo4j.time.DateTime not exported
128128 The creation time of the model.
129129
130130 """
131- return self ._list_info ()["creationTime" ]. squeeze ()
131+ return self ._list_info ()["creationTime" ]
132132
133133 def shared (self ) -> bool :
134134 """
@@ -138,7 +138,7 @@ def shared(self) -> bool:
138138 True if the model is shared, False otherwise.
139139
140140 """
141- return self ._list_info ()["shared" ]. squeeze () # type: ignore
141+ return self ._list_info ()["shared" ] # type: ignore
142142
143143 @compatible_with ("published" , min_inclusive = ServerVersion (2 , 5 , 0 ))
144144 def published (self ) -> bool :
@@ -149,17 +149,17 @@ def published(self) -> bool:
149149 True if the model is published, False otherwise.
150150
151151 """
152- return self ._list_info ()["published" ]. squeeze () # type: ignore
152+ return self ._list_info ()["published" ] # type: ignore
153153
154- def model_info (self ) -> "Series[ Any]" :
154+ def model_info (self ) -> Dict [ str , Any ]:
155155 """
156156 Get the model info of the model.
157157
158158 Returns:
159159 The model info of the model.
160160
161161 """
162- return Series (self ._list_info ()["modelInfo" ]. squeeze ())
162+ return Series (self ._list_info ()["modelInfo" ]) # type: ignore
163163
164164 def exists (self ) -> bool :
165165 """
@@ -179,7 +179,7 @@ def exists(self) -> bool:
179179 endpoint = endpoint , params = params , yields = yields , custom_error = False
180180 ).squeeze ()
181181
182- def drop (self , failIfMissing : bool = False ) -> " Series[Any]" :
182+ def drop (self , failIfMissing : bool = False ) -> Series [Any ]:
183183 """
184184 Drop the model.
185185
@@ -190,36 +190,38 @@ def drop(self, failIfMissing: bool = False) -> "Series[Any]":
190190 The result of the drop operation.
191191
192192 """
193+ params = CallParameters (model_name = self ._name , fail_if_missing = failIfMissing )
193194 if self ._server_version < ServerVersion (2 , 5 , 0 ):
194- query = "CALL gds.beta.model.drop($model_name, $fail_if_missing)"
195+ return self ._query_runner .call_procedure ( # type: ignore
196+ "gds.beta.model.drop" , params = params , custom_error = False
197+ ).squeeze ()
195198 else :
196- query = """
197- CALL gds.model.drop($model_name, $fail_if_missing)
198- YIELD
199- modelName, modelType, modelInfo,
200- creationTime, trainConfig, graphSchema,
201- loaded, stored, published
202- RETURN
203- modelName, modelType,
204- modelInfo {.*, modelName: modelName, modelType: modelType} AS modelInfo,
205- creationTime, trainConfig, graphSchema,
206- loaded, stored, published, published AS shared
207- """
208-
209- params = {"model_name" : self ._name , "fail_if_missing" : failIfMissing }
210- # FIXME use call procedure + do post processing on the client side
211- return self ._query_runner .run_cypher (query , params , custom_error = False ).squeeze () # type: ignore
212-
213- def metrics (self ) -> "Series[Any]" :
199+ result : Series [Any ] = self ._query_runner .call_procedure (
200+ "gds.model.drop" , params = params , custom_error = False
201+ ).squeeze ()
202+
203+ if result .empty :
204+ return Series ()
205+
206+ # modelInfo {.*, modelName: modelName, modelType: modelType} AS modelInfo
207+ result ["modelInfo" ] = {
208+ ** result ["modelInfo" ],
209+ "modelName" : result ["modelName" ],
210+ "modelType" : result ["modelType" ],
211+ }
212+ result ["shared" ] = result ["published" ]
213+ return result
214+
215+ def metrics (self ) -> Series [Any ]:
214216 """
215217 Get the metrics of the model.
216218
217219 Returns:
218220 The metrics of the model.
219221
220222 """
221- model_info = self ._list_info ()["modelInfo" ][ 0 ]
222- metrics : " Series[Any]" = Series (model_info ["metrics" ])
223+ model_info = self ._list_info ()["modelInfo" ]
224+ metrics : Series [Any ] = Series (model_info ["metrics" ])
223225 return metrics
224226
225227 @graph_type_check
@@ -242,7 +244,7 @@ def predict_stream(self, G: Graph, **config: Any) -> DataFrame:
242244 return self ._query_runner .call_procedure (endpoint = endpoint , params = params , logging = True )
243245
244246 @graph_type_check
245- def predict_stream_estimate (self , G : Graph , ** config : Any ) -> " Series[Any]" :
247+ def predict_stream_estimate (self , G : Graph , ** config : Any ) -> Series [Any ]:
246248 """
247249 Estimate the prediction on the given graph using the model and stream the results as DataFrame
248250
@@ -257,7 +259,7 @@ def predict_stream_estimate(self, G: Graph, **config: Any) -> "Series[Any]":
257259 return self ._estimate_predict ("stream" , G .name (), config )
258260
259261 @graph_type_check
260- def predict_mutate (self , G : Graph , ** config : Any ) -> " Series[Any]" :
262+ def predict_mutate (self , G : Graph , ** config : Any ) -> Series [Any ]:
261263 """
262264 Predict on the given graph using the model and mutate the graph with the results.
263265
@@ -278,7 +280,7 @@ def predict_mutate(self, G: Graph, **config: Any) -> "Series[Any]":
278280 ).squeeze ()
279281
280282 @graph_type_check
281- def predict_mutate_estimate (self , G : Graph , ** config : Any ) -> " Series[Any]" :
283+ def predict_mutate_estimate (self , G : Graph , ** config : Any ) -> Series [Any ]:
282284 """
283285 Estimate the memory needed to predict on the given graph using the model.
284286
0 commit comments