Skip to content

Commit 8e47d2e

Browse files
committed
Add query mode to call_procedure and call_function
1 parent d4513ef commit 8e47d2e

File tree

11 files changed

+62
-21
lines changed

11 files changed

+62
-21
lines changed

graphdatascience/query_runner/arrow_query_runner.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from pandas import DataFrame
77

88
from graphdatascience.query_runner.arrow_authentication import ArrowAuthentication
9+
from graphdatascience.query_runner.query_mode import QueryMode
910
from graphdatascience.retry_utils.retry_config import RetryConfig
1011

1112
from ..call_parameters import CallParameters
@@ -86,6 +87,7 @@ def call_procedure(
8687
params: Optional[CallParameters] = None,
8788
yields: Optional[list[str]] = None,
8889
database: Optional[str] = None,
90+
mode: QueryMode = QueryMode.READ,
8991
logging: bool = False,
9092
retryable: bool = False,
9193
custom_error: bool = True,

graphdatascience/query_runner/neo4j_query_runner.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
import neo4j
1010
from pandas import DataFrame
1111

12+
from graphdatascience.query_runner.query_mode import QueryMode
13+
1214
from ..call_parameters import CallParameters
1315
from ..error.endpoint_suggester import generate_suggestive_error_message
1416
from ..error.gds_not_installed import GdsNotFound
@@ -238,14 +240,17 @@ def call_function(self, endpoint: str, params: Optional[CallParameters] = None,
238240
query = f"RETURN {endpoint}({params.placeholder_str()})"
239241

240242
# we can use retryable cypher as we expect all gds functions to be idempotent
241-
return self.run_retryable_cypher(query, params, custom_error=custom_error).squeeze()
243+
return self.run_retryable_cypher(
244+
query, params, custom_error=custom_error, routing=neo4j.RoutingControl.READ
245+
).squeeze()
242246

243247
def call_procedure(
244248
self,
245249
endpoint: str,
246250
params: Optional[CallParameters] = None,
247251
yields: Optional[list[str]] = None,
248252
database: Optional[str] = None,
253+
mode: QueryMode = QueryMode.READ,
249254
logging: bool = False,
250255
retryable: bool = False,
251256
custom_error: bool = True,
@@ -258,8 +263,7 @@ def call_procedure(
258263

259264
def run_cypher_query() -> DataFrame:
260265
if retryable:
261-
routing = neo4j.RoutingControl.WRITE if "write" in endpoint else neo4j.RoutingControl.READ
262-
return self.run_retryable_cypher(query, params, database, custom_error, routing=routing)
266+
return self.run_retryable_cypher(query, params, database, custom_error, routing=mode.neo4j_mode())
263267
else:
264268
return self.run_cypher(query, params, database, custom_error)
265269

graphdatascience/query_runner/protocol/project_protocols.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,12 @@ def is_not_done(result: DataFrame) -> bool:
138138

139139
# We need to pin the driver to a specific cluster member
140140
response = query_runner.call_procedure(
141-
ProtocolVersion.V3.versioned_procedure_name(endpoint), params, yields, database, logging, False
141+
ProtocolVersion.V3.versioned_procedure_name(endpoint),
142+
params,
143+
yields,
144+
database,
145+
logging=logging,
146+
custom_error=False,
142147
).squeeze()
143148
member_host = response["host"]
144149
member_port = response["port"] if ("port" in response.index) else 7687

graphdatascience/query_runner/protocol/write_protocols.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from graphdatascience.call_parameters import CallParameters
99
from graphdatascience.query_runner.protocol.status import Status
10+
from graphdatascience.query_runner.query_mode import QueryMode
1011
from graphdatascience.query_runner.query_runner import QueryRunner
1112
from graphdatascience.query_runner.termination_flag import TerminationFlag
1213
from graphdatascience.retry_utils.retry_utils import before_log
@@ -76,6 +77,7 @@ def run_write_back(
7677
retryable=False,
7778
database=None,
7879
logging=False,
80+
mode=QueryMode.WRITE,
7981
custom_error=False,
8082
)
8183

@@ -115,6 +117,7 @@ def run_write_back(
115117
retryable=False,
116118
database=None,
117119
logging=False,
120+
mode=QueryMode.WRITE,
118121
custom_error=False,
119122
)
120123

@@ -161,6 +164,7 @@ def write_fn() -> DataFrame:
161164
yields,
162165
retryable=True,
163166
logging=False,
167+
mode=QueryMode.WRITE,
164168
custom_error=False,
165169
)
166170

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from enum import Enum
2+
3+
import neo4j
4+
import neo4j.routing
5+
6+
7+
class QueryMode(str, Enum):
8+
READ = "read"
9+
WRITE = "write"
10+
11+
def neo4j_mode(self) -> neo4j.RoutingControl:
12+
if self == QueryMode.READ:
13+
return neo4j.RoutingControl.READ
14+
elif self == QueryMode.WRITE:
15+
return neo4j.RoutingControl.WRITE
16+
else:
17+
raise ValueError(f"Unknown query mode: {self}")

graphdatascience/query_runner/query_runner.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33

44
from pandas import DataFrame
55

6+
from graphdatascience.query_runner.query_mode import QueryMode
7+
68
from ..call_parameters import CallParameters
79
from ..server_version.server_version import ServerVersion
810
from .graph_constructor import GraphConstructor
@@ -16,6 +18,7 @@ def call_procedure(
1618
params: Optional[CallParameters] = None,
1719
yields: Optional[list[str]] = None,
1820
database: Optional[str] = None,
21+
mode: QueryMode = QueryMode.READ,
1922
logging: bool = False,
2023
retryable: bool = False,
2124
custom_error: bool = True,

graphdatascience/query_runner/session_query_runner.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from graphdatascience.query_runner.graph_constructor import GraphConstructor
1010
from graphdatascience.query_runner.progress.query_progress_logger import QueryProgressLogger
11+
from graphdatascience.query_runner.query_mode import QueryMode
1112
from graphdatascience.query_runner.termination_flag import TerminationFlag
1213
from graphdatascience.server_version.server_version import ServerVersion
1314

@@ -75,6 +76,7 @@ def call_procedure(
7576
params: Optional[CallParameters] = None,
7677
yields: Optional[list[str]] = None,
7778
database: Optional[str] = None,
79+
mode: QueryMode = QueryMode.READ,
7880
logging: bool = False,
7981
retryable: bool = False,
8082
custom_error: bool = True,
@@ -206,7 +208,7 @@ def _remote_write_back(
206208
config["writeToResultStore"] = True
207209

208210
gds_write_result = self._gds_query_runner.call_procedure(
209-
endpoint, params, yields, database, logging, custom_error
211+
endpoint, params, yields, database=database, logging=logging, custom_error=custom_error
210212
)
211213
terminationFlag.assert_running()
212214

graphdatascience/query_runner/standalone_session_query_runner.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from graphdatascience import QueryRunner, ServerVersion
88
from graphdatascience.call_parameters import CallParameters
99
from graphdatascience.query_runner.graph_constructor import GraphConstructor
10+
from graphdatascience.query_runner.query_mode import QueryMode
1011

1112

1213
class StandaloneSessionQueryRunner(QueryRunner):
@@ -19,6 +20,7 @@ def call_procedure(
1920
params: Optional[CallParameters] = None,
2021
yields: Optional[list[str]] = None,
2122
database: Optional[str] = None,
23+
mode: QueryMode = QueryMode.READ,
2224
logging: bool = False,
2325
retryable: bool = False,
2426
custom_error: bool = True,

graphdatascience/tests/integration/test_remote_graph_ops.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def run_around_tests(gds_with_cloud_setup: AuraGraphDataScience) -> Generator[No
3636
gds_with_cloud_setup.graph.get(graph_name).drop(failIfMissing=True)
3737

3838

39-
@pytest.mark.cloud_architecture
39+
# @pytest.mark.cloud_architecture
4040
@pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 7, 0))
4141
def test_remote_projection(gds_with_cloud_setup: AuraGraphDataScience) -> None:
4242
G, result = gds_with_cloud_setup.graph.project(GRAPH_NAME, "MATCH (n)-->(m) RETURN gds.graph.project.remote(n, m)")
@@ -87,7 +87,7 @@ def test_remote_projection_and_writeback_custom_database_name(gds_with_cloud_set
8787
gds_with_cloud_setup.set_database("neo4j")
8888

8989

90-
@pytest.mark.cloud_architecture
90+
# @pytest.mark.cloud_architecture
9191
@pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 7, 0))
9292
def test_remote_projection_with_small_batch_size(gds_with_cloud_setup: AuraGraphDataScience) -> None:
9393
G, result = gds_with_cloud_setup.graph.project(
@@ -98,7 +98,7 @@ def test_remote_projection_with_small_batch_size(gds_with_cloud_setup: AuraGraph
9898
assert result["nodeCount"] == 3
9999

100100

101-
@pytest.mark.cloud_architecture
101+
# @pytest.mark.cloud_architecture
102102
@pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 7, 0))
103103
def test_remote_write_back_page_rank(gds_with_cloud_setup: AuraGraphDataScience) -> None:
104104
G, result = gds_with_cloud_setup.graph.project(GRAPH_NAME, "MATCH (n)-->(m) RETURN gds.graph.project.remote(n, m)")
@@ -108,7 +108,7 @@ def test_remote_write_back_page_rank(gds_with_cloud_setup: AuraGraphDataScience)
108108
assert result["nodePropertiesWritten"] == 3
109109

110110

111-
@pytest.mark.cloud_architecture
111+
# @pytest.mark.cloud_architecture
112112
@pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 7, 0))
113113
def test_remote_write_back_node_similarity(gds_with_cloud_setup: AuraGraphDataScience) -> None:
114114
G, result = gds_with_cloud_setup.graph.project(GRAPH_NAME, "MATCH (n)-->(m) RETURN gds.graph.project.remote(n, m)")
@@ -120,7 +120,7 @@ def test_remote_write_back_node_similarity(gds_with_cloud_setup: AuraGraphDataSc
120120
assert result["relationshipsWritten"] == 2
121121

122122

123-
@pytest.mark.cloud_architecture
123+
# @pytest.mark.cloud_architecture
124124
@pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 7, 0))
125125
def test_remote_write_back_node_properties(gds_with_cloud_setup: AuraGraphDataScience) -> None:
126126
G, result = gds_with_cloud_setup.graph.project(GRAPH_NAME, "MATCH (n)-->(m) RETURN gds.graph.project.remote(n, m)")
@@ -130,7 +130,7 @@ def test_remote_write_back_node_properties(gds_with_cloud_setup: AuraGraphDataSc
130130
assert result["propertiesWritten"] == 3
131131

132132

133-
@pytest.mark.cloud_architecture
133+
# @pytest.mark.cloud_architecture
134134
@pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 7, 0))
135135
def test_remote_write_back_node_properties_with_multiple_labels(gds_with_cloud_setup: AuraGraphDataScience) -> None:
136136
G, result = gds_with_cloud_setup.graph.project(
@@ -143,7 +143,7 @@ def test_remote_write_back_node_properties_with_multiple_labels(gds_with_cloud_s
143143
assert result["nodePropertiesWritten"] == 3
144144

145145

146-
@pytest.mark.cloud_architecture
146+
# @pytest.mark.cloud_architecture
147147
@pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 7, 0))
148148
def test_remote_write_back_node_properties_with_select_labels(gds_with_cloud_setup: AuraGraphDataScience) -> None:
149149
G, result = gds_with_cloud_setup.graph.project(
@@ -157,7 +157,7 @@ def test_remote_write_back_node_properties_with_select_labels(gds_with_cloud_set
157157
assert result["propertiesWritten"] == 1
158158

159159

160-
@pytest.mark.cloud_architecture
160+
# @pytest.mark.cloud_architecture
161161
@pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 7, 0))
162162
def test_remote_write_back_node_label(gds_with_cloud_setup: AuraGraphDataScience) -> None:
163163
G, result = gds_with_cloud_setup.graph.project(GRAPH_NAME, "MATCH (n)-->(m) RETURN gds.graph.project.remote(n, m)")
@@ -166,7 +166,7 @@ def test_remote_write_back_node_label(gds_with_cloud_setup: AuraGraphDataScience
166166
assert result["nodeLabelsWritten"] == 3
167167

168168

169-
@pytest.mark.cloud_architecture
169+
# @pytest.mark.cloud_architecture
170170
@pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 7, 0))
171171
def test_remote_write_back_relationship_topology(gds_with_cloud_setup: AuraGraphDataScience) -> None:
172172
G, result = gds_with_cloud_setup.graph.project(
@@ -177,7 +177,7 @@ def test_remote_write_back_relationship_topology(gds_with_cloud_setup: AuraGraph
177177
assert result["relationshipsWritten"] == 4
178178

179179

180-
@pytest.mark.cloud_architecture
180+
# @pytest.mark.cloud_architecture
181181
@pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 7, 0))
182182
def test_remote_write_back_relationship_property(gds_with_cloud_setup: AuraGraphDataScience) -> None:
183183
G, result = gds_with_cloud_setup.graph.project(
@@ -190,7 +190,7 @@ def test_remote_write_back_relationship_property(gds_with_cloud_setup: AuraGraph
190190
assert result["relationshipsWritten"] == 4
191191

192192

193-
@pytest.mark.cloud_architecture
193+
# @pytest.mark.cloud_architecture
194194
@pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 7, 0))
195195
def test_remote_write_back_relationship_properties(gds_with_cloud_setup: AuraGraphDataScience) -> None:
196196
G, result = gds_with_cloud_setup.graph.project(
@@ -207,7 +207,7 @@ def test_remote_write_back_relationship_properties(gds_with_cloud_setup: AuraGra
207207
assert result["relationshipsWritten"] == 4
208208

209209

210-
@pytest.mark.cloud_architecture
210+
# @pytest.mark.cloud_architecture
211211
@pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 7, 0))
212212
def test_remote_write_back_relationship_property_from_pathfinding_algo(
213213
gds_with_cloud_setup: AuraGraphDataScience,
@@ -223,7 +223,7 @@ def test_remote_write_back_relationship_property_from_pathfinding_algo(
223223
assert result["relationshipsWritten"] == 1
224224

225225

226-
@pytest.mark.cloud_architecture
226+
# @pytest.mark.cloud_architecture
227227
@pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 7, 0))
228228
def test_empty_graph_write_back(
229229
gds_with_cloud_setup: AuraGraphDataScience,

graphdatascience/tests/integration/test_remote_util_ops.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,15 +50,15 @@ def G(gds_with_cloud_setup: AuraGraphDataScience) -> Generator[Graph, None, None
5050
gds_with_cloud_setup.run_cypher("MATCH (n) DETACH DELETE n")
5151

5252

53-
@pytest.mark.cloud_architecture
53+
# @pytest.mark.cloud_architecture
5454
@pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 7, 0))
5555
def test_remote_util_as_node(gds_with_cloud_setup: AuraGraphDataScience) -> None:
5656
id = gds_with_cloud_setup.find_node_id(["Location"], {"name": "A"})
5757
result = gds_with_cloud_setup.util.asNode(id)
5858
assert result["name"] == "A"
5959

6060

61-
@pytest.mark.cloud_architecture
61+
# @pytest.mark.cloud_architecture
6262
@pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 7, 0))
6363
def test_remote_util_as_nodes(gds_with_cloud_setup: AuraGraphDataScience) -> None:
6464
ids = [
@@ -69,7 +69,7 @@ def test_remote_util_as_nodes(gds_with_cloud_setup: AuraGraphDataScience) -> Non
6969
assert len(result) == 2
7070

7171

72-
@pytest.mark.cloud_architecture
72+
# @pytest.mark.cloud_architecture
7373
@pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 7, 0))
7474
def test_util_nodeProperty(gds_with_cloud_setup: AuraGraphDataScience, G: Graph) -> None:
7575
id = gds_with_cloud_setup.find_node_id(["Location"], {"name": "A"})

0 commit comments

Comments
 (0)