11from __future__ import annotations
22
33from abc import ABC , abstractmethod
4- from typing import Any , Dict , Optional
4+ from typing import Any , Dict
55
66from pandas import DataFrame , Series
77
@@ -23,33 +23,30 @@ def __init__(self, name: str, query_runner: QueryRunner, server_version: ServerV
2323 def _endpoint_prefix (self ) -> str :
2424 pass
2525
26- def _list_info (self ) -> DataFrame :
26+ def _list_info (self ) -> Series [Any ]:
27+ params = CallParameters (name = self .name ())
28+
29+ result : Series [Any ]
2730 if self ._server_version < ServerVersion (2 , 5 , 0 ):
28- query = "CALL gds.beta.model.list($name)"
31+ result = self ._query_runner .call_procedure (
32+ "gds.beta.model.list" , params = params , custom_error = False
33+ ).squeeze ()
2934 else :
30- query = """
31- CALL gds.model.list($name)
32- YIELD
33- modelName, modelType, modelInfo,
34- creationTime, trainConfig, graphSchema,
35- loaded, stored, published
36- RETURN
37- modelName, modelType,
38- modelInfo {.*, modelName: modelName, modelType: modelType} AS modelInfo,
39- creationTime, trainConfig, graphSchema,
40- loaded, stored, published, published AS shared
41- """
42-
43- params = {"name" : self .name ()}
44-
45- # FIXME use call procedure + do post processing on the client side
46- # TODO really fixme
47- info = self ._query_runner .run_cypher (query , params , custom_error = False )
48-
49- if len (info ) == 0 :
35+ result = self ._query_runner .call_procedure ("gds.model.list" , params = params , custom_error = False ).squeeze ()
36+
37+ if not result .empty :
38+ # 2.5 compat format
39+ result ["modelInfo" ] = {
40+ ** result ["modelInfo" ],
41+ "modelName" : result ["modelName" ],
42+ "modelType" : result ["modelType" ],
43+ }
44+ result ["shared" ] = result ["published" ]
45+
46+ if result .empty :
5047 raise ValueError (f"There is no '{ self .name ()} ' in the model catalog" )
5148
52- return info
49+ return result
5350
5451 def _estimate_predict (self , predict_mode : str , graph_name : str , config : Dict [str , Any ]) -> Series [Any ]:
5552 endpoint = f"{ self ._endpoint_prefix ()} { predict_mode } .estimate"
@@ -79,7 +76,7 @@ def type(self) -> str:
7976 The type of the model.
8077
8178 """
82- return self ._list_info ()["modelInfo" ][0 ][ "modelType" ] # type: ignore
79+ return self ._list_info ()["modelInfo" ]["modelType" ] # type: ignore
8380
8481 def train_config (self ) -> Series [Any ]:
8582 """
@@ -89,7 +86,7 @@ def train_config(self) -> Series[Any]:
8986 The train config of the model.
9087
9188 """
92- train_config : Series [Any ] = Series (self ._list_info ()["trainConfig" ][ 0 ] )
89+ train_config : Series [Any ] = Series (self ._list_info ()["trainConfig" ])
9390 return train_config
9491
9592 def graph_schema (self ) -> Series [Any ]:
@@ -100,7 +97,7 @@ def graph_schema(self) -> Series[Any]:
10097 The graph schema of the model.
10198
10299 """
103- graph_schema : Series [Any ] = Series (self ._list_info ()["graphSchema" ][ 0 ] )
100+ graph_schema : Series [Any ] = Series (self ._list_info ()["graphSchema" ])
104101 return graph_schema
105102
106103 def loaded (self ) -> bool :
@@ -111,7 +108,7 @@ def loaded(self) -> bool:
111108 True if the model is loaded in memory, False otherwise.
112109
113110 """
114- return self ._list_info ()["loaded" ]. squeeze () # type: ignore
111+ return self ._list_info ()["loaded" ] # type: ignore
115112
116113 def stored (self ) -> bool :
117114 """
@@ -121,7 +118,7 @@ def stored(self) -> bool:
121118 True if the model is stored on disk, False otherwise.
122119
123120 """
124- return self ._list_info ()["stored" ]. squeeze () # type: ignore
121+ return self ._list_info ()["stored" ] # type: ignore
125122
126123 def creation_time (self ) -> Any : # neo4j.time.DateTime not exported
127124 """
@@ -131,7 +128,7 @@ def creation_time(self) -> Any: # neo4j.time.DateTime not exported
131128 The creation time of the model.
132129
133130 """
134- return self ._list_info ()["creationTime" ]. squeeze ()
131+ return self ._list_info ()["creationTime" ]
135132
136133 def shared (self ) -> bool :
137134 """
@@ -141,7 +138,7 @@ def shared(self) -> bool:
141138 True if the model is shared, False otherwise.
142139
143140 """
144- return self ._list_info ()["shared" ]. squeeze () # type: ignore
141+ return self ._list_info ()["shared" ] # type: ignore
145142
146143 @compatible_with ("published" , min_inclusive = ServerVersion (2 , 5 , 0 ))
147144 def published (self ) -> bool :
@@ -152,17 +149,17 @@ def published(self) -> bool:
152149 True if the model is published, False otherwise.
153150
154151 """
155- return self ._list_info ()["published" ]. squeeze () # type: ignore
152+ return self ._list_info ()["published" ] # type: ignore
156153
157- def model_info (self ) -> Series [ Any ]:
154+ def model_info (self ) -> Dict [ str , Any ]:
158155 """
159156 Get the model info of the model.
160157
161158 Returns:
162159 The model info of the model.
163160
164161 """
165- return Series (self ._list_info ()["modelInfo" ]. squeeze ())
162+ return Series (self ._list_info ()["modelInfo" ]) # type: ignore
166163
167164 def exists (self ) -> bool :
168165 """
@@ -199,11 +196,11 @@ def drop(self, failIfMissing: bool = False) -> Series[Any]:
199196 "gds.beta.model.drop" , params = params , custom_error = False
200197 ).squeeze ()
201198 else :
202- result : Optional [ Series [Any ] ] = self ._query_runner .call_procedure (
199+ result : Series [Any ] = self ._query_runner .call_procedure (
203200 "gds.model.drop" , params = params , custom_error = False
204201 ).squeeze ()
205202
206- if result is None :
203+ if result . empty :
207204 return Series ()
208205
209206 # modelInfo {.*, modelName: modelName, modelType: modelType} AS modelInfo
@@ -223,7 +220,7 @@ def metrics(self) -> Series[Any]:
223220 The metrics of the model.
224221
225222 """
226- model_info = self ._list_info ()["modelInfo" ][ 0 ]
223+ model_info = self ._list_info ()["modelInfo" ]
227224 metrics : Series [Any ] = Series (model_info ["metrics" ])
228225 return metrics
229226
0 commit comments