Skip to content

Commit 5477064

Browse files
committed
Check for signals during projection and write-back
1 parent 0f8643f commit 5477064

File tree

3 files changed

+38
-3
lines changed

3 files changed

+38
-3
lines changed

graphdatascience/query_runner/protocol/project_protocols.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import signal
12
from abc import ABC, abstractmethod
23
from typing import Any, Optional
34

@@ -6,6 +7,7 @@
67

78
from graphdatascience import QueryRunner
89
from graphdatascience.call_parameters import CallParameters
10+
from graphdatascience.query_runner.protocol.retry_utils import retry_unless_signal
911
from graphdatascience.query_runner.protocol.status import Status
1012
from graphdatascience.session.dbms.protocol_version import ProtocolVersion
1113

@@ -123,7 +125,10 @@ def is_not_done(result: DataFrame) -> bool:
123125
status: str = result.squeeze()["status"]
124126
return status != Status.DONE.name
125127

126-
@retry(retry=retry_if_result(is_not_done), wait=wait_incrementing(start=0.2, increment=0.2, max=2))
128+
@retry(
129+
retry=retry_if_result(is_not_done) and retry_unless_signal([signal.SIGTERM, signal.SIGINT]),
130+
wait=wait_incrementing(start=0.2, increment=0.2, max=2),
131+
)
127132
def project_fn() -> DataFrame:
128133
return query_runner.call_procedure(
129134
ProtocolVersion.V3.versioned_procedure_name(endpoint), params, yields, database, logging, False
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import signal
2+
from types import FrameType
3+
from typing import Optional
4+
5+
from tenacity import RetryCallState, retry_base
6+
7+
8+
class retry_unless_signal(retry_base):
9+
"""Retries unless one of the given signals is raised."""
10+
11+
def __init__(self, signals: list[signal.Signals]) -> None:
12+
self.signal_received = False
13+
14+
def receive_signal(sig: int, frame: Optional[FrameType]) -> None:
15+
self.signal_received = True
16+
17+
for sig in signals:
18+
signal.signal(sig, receive_signal)
19+
20+
def __call__(self, retry_state: RetryCallState) -> bool:
21+
return not self.signal_received

graphdatascience/query_runner/protocol/write_protocols.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
1+
import logging
2+
import signal
13
from abc import ABC, abstractmethod
24
from typing import Any, Optional
35

46
from pandas import DataFrame
5-
from tenacity import retry, retry_if_result, wait_incrementing
7+
from tenacity import after_log, retry, retry_if_result, wait_incrementing
68

79
from graphdatascience import QueryRunner
810
from graphdatascience.call_parameters import CallParameters
11+
from graphdatascience.query_runner.protocol.retry_utils import retry_unless_signal
912
from graphdatascience.query_runner.protocol.status import Status
1013
from graphdatascience.session.dbms.protocol_version import ProtocolVersion
1114

@@ -120,7 +123,13 @@ def is_not_completed(result: DataFrame) -> bool:
120123
status: str = result.squeeze()["status"]
121124
return status != Status.COMPLETED.name
122125

123-
@retry(retry=retry_if_result(is_not_completed), wait=wait_incrementing(start=0.2, increment=0.2, max=2))
126+
logger = logging.getLogger()
127+
128+
@retry(
129+
retry=retry_if_result(is_not_completed) and retry_unless_signal([signal.SIGTERM, signal.SIGINT]),
130+
wait=wait_incrementing(start=0.2, increment=0.2, max=2),
131+
after=after_log(logger, logging.WARN),
132+
)
124133
def write_fn() -> DataFrame:
125134
return query_runner.call_procedure(
126135
ProtocolVersion.V3.versioned_procedure_name("gds.arrow.write"),

0 commit comments

Comments
 (0)