Skip to content

Commit caa0626

Browse files
committed
Allow interruption also if run with progress logging
1 parent e036ab0 commit caa0626

File tree

6 files changed

+113
-14
lines changed

6 files changed

+113
-14
lines changed

changelog.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
* Display progress bar for remote projection and open-ended tasks.
2121
* Improve progress bar by showing the description of the currently running task.
2222
* Allow passing the optional graph filter also as type `str` to `gds.graph.list()` instead of only `Graph`.
23+
* Listen and to SIGINT and SIGTERM and interrupt projection and write-backs based on GDS Session. Note this only works if the query runs in the main thread.
2324

2425

2526
## Other changes

graphdatascience/query_runner/progress/query_progress_logger.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ def _log(
8282

8383
if pbar is not None:
8484
self._finish_pbar(pbar)
85+
# TODO show as cancelled if future was interrupted
8586

8687
def _init_pbar(self, task_with_progress: TaskWithProgress) -> tqdm: # type: ignore
8788
root_task_name = task_with_progress.task_name

graphdatascience/query_runner/protocol/project_protocols.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import signal
21
from abc import ABC, abstractmethod
32
from logging import DEBUG, getLogger
43
from typing import Any, Optional
@@ -8,8 +7,9 @@
87

98
from graphdatascience import QueryRunner
109
from graphdatascience.call_parameters import CallParameters
11-
from graphdatascience.query_runner.protocol.retry_utils import before_log, retry_unless_signal
10+
from graphdatascience.query_runner.protocol.retry_utils import before_log
1211
from graphdatascience.query_runner.protocol.status import Status
12+
from graphdatascience.query_runner.termination_flag import TerminationFlag
1313
from graphdatascience.session.dbms.protocol_version import ProtocolVersion
1414

1515

@@ -27,6 +27,7 @@ def run_projection(
2727
query_runner: QueryRunner,
2828
endpoint: str,
2929
params: CallParameters,
30+
terminationFlag: TerminationFlag,
3031
yields: Optional[list[str]] = None,
3132
database: Optional[str] = None,
3233
logging: bool = False,
@@ -61,6 +62,7 @@ def run_projection(
6162
query_runner: QueryRunner,
6263
endpoint: str,
6364
params: CallParameters,
65+
terminationFlag: TerminationFlag,
6466
yields: Optional[list[str]] = None,
6567
database: Optional[str] = None,
6668
logging: bool = False,
@@ -89,6 +91,7 @@ def run_projection(
8991
query_runner: QueryRunner,
9092
endpoint: str,
9193
params: CallParameters,
94+
terminationFlag: TerminationFlag,
9295
yields: Optional[list[str]] = None,
9396
database: Optional[str] = None,
9497
logging: bool = False,
@@ -118,6 +121,7 @@ def run_projection(
118121
query_runner: QueryRunner,
119122
endpoint: str,
120123
params: CallParameters,
124+
terminationFlag: TerminationFlag,
121125
yields: Optional[list[str]] = None,
122126
database: Optional[str] = None,
123127
logging: bool = False,
@@ -130,10 +134,11 @@ def is_not_done(result: DataFrame) -> bool:
130134

131135
@retry(
132136
before=before_log(f"Projection (graph: `{params['graph_name']}`)", logger, DEBUG),
133-
retry=retry_if_result(is_not_done) and retry_unless_signal([signal.SIGTERM, signal.SIGINT]),
137+
retry=retry_if_result(is_not_done),
134138
wait=wait_incrementing(start=0.2, increment=0.2, max=2),
135139
)
136140
def project_fn() -> DataFrame:
141+
terminationFlag.assert_running()
137142
return query_runner.call_procedure(
138143
ProtocolVersion.V3.versioned_procedure_name(endpoint), params, yields, database, logging, False
139144
)

graphdatascience/query_runner/protocol/write_protocols.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import logging
2-
import signal
32
from abc import ABC, abstractmethod
43
from typing import Any, Optional
54

@@ -8,8 +7,9 @@
87

98
from graphdatascience import QueryRunner
109
from graphdatascience.call_parameters import CallParameters
11-
from graphdatascience.query_runner.protocol.retry_utils import before_log, retry_unless_signal
10+
from graphdatascience.query_runner.protocol.retry_utils import before_log
1211
from graphdatascience.query_runner.protocol.status import Status
12+
from graphdatascience.query_runner.termination_flag import TerminationFlag
1313
from graphdatascience.session.dbms.protocol_version import ProtocolVersion
1414

1515

@@ -28,7 +28,11 @@ def write_back_params(
2828

2929
@abstractmethod
3030
def run_write_back(
31-
self, query_runner: QueryRunner, parameters: CallParameters, yields: Optional[list[str]]
31+
self,
32+
query_runner: QueryRunner,
33+
parameters: CallParameters,
34+
yields: Optional[list[str]],
35+
terminationFlag: TerminationFlag,
3236
) -> DataFrame:
3337
"""Executes the write-back procedure"""
3438
pass
@@ -59,7 +63,11 @@ def write_back_params(
5963
)
6064

6165
def run_write_back(
62-
self, query_runner: QueryRunner, parameters: CallParameters, yields: Optional[list[str]]
66+
self,
67+
query_runner: QueryRunner,
68+
parameters: CallParameters,
69+
yields: Optional[list[str]],
70+
terminationFlag: TerminationFlag,
6371
) -> DataFrame:
6472
return query_runner.call_procedure(
6573
ProtocolVersion.V1.versioned_procedure_name("gds.arrow.write"),
@@ -93,7 +101,11 @@ def write_back_params(
93101
)
94102

95103
def run_write_back(
96-
self, query_runner: QueryRunner, parameters: CallParameters, yields: Optional[list[str]]
104+
self,
105+
query_runner: QueryRunner,
106+
parameters: CallParameters,
107+
yields: Optional[list[str]],
108+
terminationFlag: TerminationFlag,
97109
) -> DataFrame:
98110
return query_runner.call_procedure(
99111
ProtocolVersion.V2.versioned_procedure_name("gds.arrow.write"),
@@ -117,7 +129,11 @@ def write_back_params(
117129
return RemoteWriteBackV2().write_back_params(graph_name, job_id, config, arrow_config, database)
118130

119131
def run_write_back(
120-
self, query_runner: QueryRunner, parameters: CallParameters, yields: Optional[list[str]]
132+
self,
133+
query_runner: QueryRunner,
134+
parameters: CallParameters,
135+
yields: Optional[list[str]],
136+
terminationFlag: TerminationFlag,
121137
) -> DataFrame:
122138
def is_not_completed(result: DataFrame) -> bool:
123139
status: str = result.squeeze()["status"]
@@ -126,7 +142,7 @@ def is_not_completed(result: DataFrame) -> bool:
126142
logger = logging.getLogger()
127143

128144
@retry(
129-
retry=retry_if_result(is_not_completed) and retry_unless_signal([signal.SIGTERM, signal.SIGINT]),
145+
retry=retry_if_result(is_not_completed),
130146
wait=wait_incrementing(start=0.2, increment=0.2, max=2),
131147
before=before_log(
132148
f"Write-Back (graph: `{parameters['graphName']}`, jobId: `{parameters['jobId']}`)",
@@ -135,6 +151,7 @@ def is_not_completed(result: DataFrame) -> bool:
135151
),
136152
)
137153
def write_fn() -> DataFrame:
154+
terminationFlag.assert_running()
138155
return query_runner.call_procedure(
139156
ProtocolVersion.V3.versioned_procedure_name("gds.arrow.write"),
140157
parameters,

graphdatascience/query_runner/session_query_runner.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from graphdatascience.query_runner.graph_constructor import GraphConstructor
1010
from graphdatascience.query_runner.progress.query_progress_logger import QueryProgressLogger
11+
from graphdatascience.query_runner.termination_flag import TerminationFlag
1112
from graphdatascience.server_version.server_version import ServerVersion
1213

1314
from ..call_parameters import CallParameters
@@ -69,10 +70,12 @@ def call_procedure(
6970
params = CallParameters()
7071

7172
if SessionQueryRunner.GDS_REMOTE_PROJECTION_PROC_NAME in endpoint:
72-
return self._remote_projection(endpoint, params, yields, database, logging)
73+
terminationFlag = TerminationFlag.create()
74+
return self._remote_projection(endpoint, params, terminationFlag, yields, database, logging)
7375

7476
elif ".write" in endpoint and self.is_remote_projected_graph(params["graph_name"]):
75-
return self._remote_write_back(endpoint, params, yields, database, logging, custom_error)
77+
terminationFlag = TerminationFlag.create()
78+
return self._remote_write_back(endpoint, params, terminationFlag, yields, database, logging, custom_error)
7679

7780
return self._gds_query_runner.call_procedure(endpoint, params, yields, database, logging, custom_error)
7881

@@ -126,6 +129,7 @@ def _remote_projection(
126129
self,
127130
endpoint: str,
128131
params: CallParameters,
132+
terminationFlag: TerminationFlag,
129133
yields: Optional[list[str]] = None,
130134
database: Optional[str] = None,
131135
logging: bool = False,
@@ -144,7 +148,7 @@ def _remote_projection(
144148

145149
def run_projection() -> DataFrame:
146150
return project_protocol.run_projection(
147-
self._db_query_runner, endpoint, project_params, yields, database, logging
151+
self._db_query_runner, endpoint, project_params, terminationFlag, yields, database, logging
148152
)
149153

150154
if self._resolve_show_progress(logging):
@@ -159,6 +163,7 @@ def _remote_write_back(
159163
self,
160164
endpoint: str,
161165
params: CallParameters,
166+
terminationFlag: TerminationFlag,
162167
yields: Optional[list[str]] = None,
163168
database: Optional[str] = None,
164169
logging: bool = False,
@@ -180,6 +185,7 @@ def _remote_write_back(
180185
gds_write_result = self._gds_query_runner.call_procedure(
181186
endpoint, params, yields, database, logging, custom_error
182187
)
188+
terminationFlag.assert_running()
183189

184190
self._inject_arrow_config(db_arrow_config)
185191

@@ -193,7 +199,7 @@ def _remote_write_back(
193199
write_back_start = time.time()
194200

195201
def run_write_back() -> DataFrame:
196-
return write_protocol.run_write_back(self._db_query_runner, write_back_params, yields)
202+
return write_protocol.run_write_back(self._db_query_runner, write_back_params, yields, terminationFlag)
197203

198204
try:
199205
if self._resolve_show_progress(logging):
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
from __future__ import annotations
2+
3+
import logging
4+
import signal
5+
import threading
6+
from abc import ABC, abstractmethod
7+
from types import FrameType
8+
from typing import Optional
9+
10+
11+
class TerminationFlag(ABC):
12+
@abstractmethod
13+
def is_set(self) -> bool:
14+
pass
15+
16+
@abstractmethod
17+
def set(self) -> None:
18+
pass
19+
20+
def assert_running(self) -> None:
21+
if self.is_set():
22+
raise RuntimeError("Query has been terminated")
23+
24+
@staticmethod
25+
def create(signals: Optional[list[signal.Signals]] = None) -> TerminationFlag:
26+
if signals is None:
27+
signals = [signal.SIGINT, signal.SIGTERM]
28+
29+
if threading.current_thread() == threading.main_thread():
30+
return TerminationFlagImpl(signals)
31+
else:
32+
logging.debug("Cannot set terminationFlag for query runner in non-main thread")
33+
return TerminationFlagNoop()
34+
35+
36+
class TerminationFlagImpl(TerminationFlag):
37+
def __init__(self, signals: list[signal.Signals]) -> None:
38+
self._event = threading.Event()
39+
40+
def receive_signal(sig: int, frame: Optional[FrameType]) -> None:
41+
logging.debug(f"Received signal {sig}. Interrupting query.")
42+
self._event.set()
43+
44+
for sig in signals:
45+
signal.signal(sig, receive_signal)
46+
47+
def is_set(self) -> bool:
48+
return self._event.is_set()
49+
50+
def set(self) -> None:
51+
self._event.set()
52+
53+
def assert_running(self) -> None:
54+
if self.is_set():
55+
raise RuntimeError("Query has been terminated")
56+
57+
58+
class TerminationFlagNoop(TerminationFlag):
59+
def __init__(self) -> None:
60+
pass
61+
62+
def is_set(self) -> bool:
63+
return False
64+
65+
def set(self) -> None:
66+
pass
67+
68+
def assert_running(self) -> None:
69+
pass

0 commit comments

Comments
 (0)