@@ -1,5 +1,7 @@
+from __future__ import annotations
+
 from abc import ABC, abstractmethod
-from typing import Any, Dict
+from typing import Any, Dict, Optional

 from pandas import DataFrame, Series

@@ -41,14 +43,15 @@ def _list_info(self) -> DataFrame:
         params = {"name": self.name()}

         # FIXME use call procedure + do post processing on the client side
+        # TODO really fixme
         info = self._query_runner.run_cypher(query, params, custom_error=False)

         if len(info) == 0:
             raise ValueError(f"There is no '{self.name()}' in the model catalog")

         return info

-    def _estimate_predict(self, predict_mode: str, graph_name: str, config: Dict[str, Any]) -> "Series[Any]":
+    def _estimate_predict(self, predict_mode: str, graph_name: str, config: Dict[str, Any]) -> Series[Any]:
         endpoint = f"{self._endpoint_prefix()}{predict_mode}.estimate"
         config["modelName"] = self.name()
         params = CallParameters(graph_name=graph_name, config=config)
@@ -78,26 +81,26 @@ def type(self) -> str:
         """
         return self._list_info()["modelInfo"][0]["modelType"]  # type: ignore

-    def train_config(self) -> "Series[Any]":
+    def train_config(self) -> Series[Any]:
         """
         Get the train config of the model.

         Returns:
             The train config of the model.

         """
-        train_config: "Series[Any]" = Series(self._list_info()["trainConfig"][0])
+        train_config: Series[Any] = Series(self._list_info()["trainConfig"][0])
         return train_config

-    def graph_schema(self) -> "Series[Any]":
+    def graph_schema(self) -> Series[Any]:
         """
         Get the graph schema of the model.

         Returns:
             The graph schema of the model.

         """
-        graph_schema: "Series[Any]" = Series(self._list_info()["graphSchema"][0])
+        graph_schema: Series[Any] = Series(self._list_info()["graphSchema"][0])
         return graph_schema

     def loaded(self) -> bool:
@@ -151,7 +154,7 @@ def published(self) -> bool:
         """
         return self._list_info()["published"].squeeze()  # type: ignore

-    def model_info(self) -> "Series[Any]":
+    def model_info(self) -> Series[Any]:
         """
         Get the model info of the model.

@@ -179,7 +182,7 @@ def exists(self) -> bool:
             endpoint=endpoint, params=params, yields=yields, custom_error=False
         ).squeeze()

-    def drop(self, failIfMissing: bool = False) -> "Series[Any]":
+    def drop(self, failIfMissing: bool = False) -> Series[Any]:
         """
         Drop the model.

@@ -190,27 +193,29 @@ def drop(self, failIfMissing: bool = False) -> "Series[Any]":
             The result of the drop operation.

         """
+        params = CallParameters(model_name=self._name, fail_if_missing=failIfMissing)
         if self._server_version < ServerVersion(2, 5, 0):
-            query = "CALL gds.beta.model.drop($model_name, $fail_if_missing)"
+            return self._query_runner.call_procedure(  # type: ignore
+                "gds.beta.model.drop", params=params, custom_error=False
+            ).squeeze()
         else:
-            query = """
-                CALL gds.model.drop($model_name, $fail_if_missing)
-                YIELD
-                modelName, modelType, modelInfo,
-                creationTime, trainConfig, graphSchema,
-                loaded, stored, published
-                RETURN
-                modelName, modelType,
-                modelInfo {.*, modelName: modelName, modelType: modelType} AS modelInfo,
-                creationTime, trainConfig, graphSchema,
-                loaded, stored, published, published AS shared
-            """
+            result: Optional[Series[Any]] = self._query_runner.call_procedure(
+                "gds.model.drop", params=params, custom_error=False
+            ).squeeze()

-        params = {"model_name": self._name, "fail_if_missing": failIfMissing}
-        # FIXME use call procedure + do post processing on the client side
-        return self._query_runner.run_cypher(query, params, custom_error=False).squeeze()  # type: ignore
+            if result is None:
+                return Series()
+
+            # modelInfo {.*, modelName: modelName, modelType: modelType} AS modelInfo
+            result["modelInfo"] = {
+                **result["modelInfo"],
+                "modelName": result["modelName"],
+                "modelType": result["modelType"],
+            }
+            result["shared"] = result["published"]
+            return result

-    def metrics(self) -> "Series[Any]":
+    def metrics(self) -> Series[Any]:
         """
         Get the metrics of the model.

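For reference, here is a minimal, self-contained sketch of what the new client-side post-processing in drop() does, with made-up field values; it is meant to mirror the projection the removed Cypher query previously performed on the server (the map projection on modelInfo and the "published AS shared" alias), not to be the library's own code:

from pandas import Series

# Hypothetical raw row, shaped like a gds.model.drop result
# (all field values are made up for illustration).
raw = Series(
    {
        "modelName": "my-model",
        "modelType": "graphSage",
        "modelInfo": {"metrics": {}},
        "published": False,
    }
)

# Client-side equivalent of the removed server-side projection:
# modelInfo {.*, modelName: modelName, modelType: modelType} AS modelInfo
raw["modelInfo"] = {
    **raw["modelInfo"],
    "modelName": raw["modelName"],
    "modelType": raw["modelType"],
}
# ... and of `published AS shared`.
raw["shared"] = raw["published"]

print(raw["modelInfo"])  # {'metrics': {}, 'modelName': 'my-model', 'modelType': 'graphSage'}
print(raw["shared"])     # False
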
@@ -219,7 +224,7 @@ def metrics(self) -> "Series[Any]":

         """
         model_info = self._list_info()["modelInfo"][0]
-        metrics: "Series[Any]" = Series(model_info["metrics"])
+        metrics: Series[Any] = Series(model_info["metrics"])
         return metrics

     @graph_type_check
@@ -242,7 +247,7 @@ def predict_stream(self, G: Graph, **config: Any) -> DataFrame:
         return self._query_runner.call_procedure(endpoint=endpoint, params=params, logging=True)

     @graph_type_check
-    def predict_stream_estimate(self, G: Graph, **config: Any) -> "Series[Any]":
+    def predict_stream_estimate(self, G: Graph, **config: Any) -> Series[Any]:
         """
         Estimate the prediction on the given graph using the model and stream the results as DataFrame

@@ -257,7 +262,7 @@ def predict_stream_estimate(self, G: Graph, **config: Any) -> "Series[Any]":
         return self._estimate_predict("stream", G.name(), config)

     @graph_type_check
-    def predict_mutate(self, G: Graph, **config: Any) -> "Series[Any]":
+    def predict_mutate(self, G: Graph, **config: Any) -> Series[Any]:
         """
         Predict on the given graph using the model and mutate the graph with the results.

@@ -278,7 +283,7 @@ def predict_mutate(self, G: Graph, **config: Any) -> "Series[Any]":
         ).squeeze()

     @graph_type_check
-    def predict_mutate_estimate(self, G: Graph, **config: Any) -> "Series[Any]":
+    def predict_mutate_estimate(self, G: Graph, **config: Any) -> Series[Any]:
         """
         Estimate the memory needed to predict on the given graph using the model.

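A hedged usage sketch of the Model methods touched in this diff. It assumes a reachable GDS server, an already trained model in the catalog named "my-model", and that the installed client exposes gds.model.get; the connection details and the model name are placeholders:

from graphdatascience import GraphDataScience

# Placeholder connection details; adjust to your environment.
gds = GraphDataScience("bolt://localhost:7687", auth=("neo4j", "password"))

model = gds.model.get("my-model")          # handle to a model in the catalog (assumed to exist)
print(model.type())                        # e.g. "graphSage"
print(model.train_config())                # Series with the training configuration
print(model.metrics())                     # Series with the training metrics

result = model.drop(failIfMissing=False)   # Series describing the dropped model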