Skip to content

Commit a7136a8

Browse files
committed
Support call_functions in QueryRunner
Allows to route to GDS session
1 parent 32ad2eb commit a7136a8

File tree

5 files changed

+28
-2
lines changed

5 files changed

+28
-2
lines changed

graphdatascience/query_runner/arrow_query_runner.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,9 @@ def run_cypher(
6464
) -> DataFrame:
6565
return self._fallback_query_runner.run_cypher(query, params, database, custom_error)
6666

67+
def call_function(self, endpoint: str, params: Optional[CallParameters] = None) -> Any:
68+
return self._fallback_query_runner.call_function(endpoint, params)
69+
6770
def call_procedure(
6871
self,
6972
endpoint: str,

graphdatascience/query_runner/aura_db_query_runner.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@ def run_cypher(
3535
) -> DataFrame:
3636
return self._db_query_runner.run_cypher(query, params, database, custom_error)
3737

38+
def call_function(self, endpoint: str, params: Optional[CallParameters] = None) -> Any:
39+
return self._gds_query_runner.call_function(endpoint, params)
40+
3841
def call_procedure(
3942
self,
4043
endpoint: str,

graphdatascience/query_runner/neo4j_query_runner.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,13 @@ def run_cypher(
131131

132132
return df
133133

134+
def call_function(self, endpoint: str, params: Optional[CallParameters] = None) -> Any:
135+
if params is None:
136+
params = CallParameters()
137+
query = f"RETURN {endpoint}({params.placeholder_str()})"
138+
139+
return self.run_cypher(query, params).squeeze()
140+
134141
def call_procedure(
135142
self,
136143
endpoint: str,

graphdatascience/query_runner/query_runner.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@ def call_procedure(
2121
) -> DataFrame:
2222
pass
2323

24+
@abstractmethod
25+
def call_function(self, endpoint: str, params: Optional[CallParameters] = None) -> Any:
26+
pass
27+
2428
@abstractmethod
2529
def run_cypher(
2630
self,

graphdatascience/session/aura_graph_data_science.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,11 @@
33
from pandas import DataFrame
44

55
from graphdatascience.call_builder import IndirectCallBuilder
6-
from graphdatascience.endpoints import AlphaRemoteEndpoints, BetaEndpoints, DirectEndpoints
6+
from graphdatascience.endpoints import (
7+
AlphaRemoteEndpoints,
8+
BetaEndpoints,
9+
DirectEndpoints,
10+
)
711
from graphdatascience.error.uncallable_namespace import UncallableNamespace
812
from graphdatascience.graph.graph_remote_proc_runner import GraphRemoteProcRunner
913
from graphdatascience.query_runner.arrow_query_runner import ArrowQueryRunner
@@ -12,6 +16,7 @@
1216
from graphdatascience.query_runner.neo4j_query_runner import Neo4jQueryRunner
1317
from graphdatascience.server_version.server_version import ServerVersion
1418
from graphdatascience.session.dbms_connection_info import DbmsConnectionInfo
19+
from graphdatascience.utils.util_remote_proc_runner import UtilRemoteProcRunner
1520

1621

1722
class AuraGraphDataScience(DirectEndpoints, UncallableNamespace):
@@ -75,7 +80,7 @@ def __init__(
7580

7681
self._delete_fn = delete_fn
7782

78-
super().__init__(self._query_runner, "gds", self._server_version)
83+
super().__init__(self._query_runner, namespace="gds", server_version=self._server_version)
7984

8085
def run_cypher(
8186
self, query: str, params: Optional[Dict[str, Any]] = None, database: Optional[str] = None
@@ -102,6 +107,10 @@ def run_cypher(
102107
def graph(self) -> GraphRemoteProcRunner:
103108
return GraphRemoteProcRunner(self._query_runner, f"{self._namespace}.graph", self._server_version)
104109

110+
@property
111+
def util(self) -> UtilRemoteProcRunner:
112+
return UtilRemoteProcRunner(self._query_runner, f"{self._namespace}.util", self._server_version)
113+
105114
@property
106115
def alpha(self) -> AlphaRemoteEndpoints:
107116
return AlphaRemoteEndpoints(self._query_runner, "gds.alpha", self._server_version)

0 commit comments

Comments
 (0)