Skip to content

Commit a6f18ce

Browse files
committed
Allow interruption during wait for job
1 parent 91f3f28 commit a6f18ce

File tree

3 files changed

+35
-5
lines changed

3 files changed

+35
-5
lines changed

graphdatascience/arrow_client/v2/job_client.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,13 @@
33

44
from pandas import ArrowDtype, DataFrame
55
from pyarrow._flight import Ticket
6-
from tenacity import Retrying, retry_always, wait_exponential
6+
from tenacity import Retrying, retry_if_result, wait_exponential
77

88
from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient
99
from graphdatascience.arrow_client.v2.api_types import JobIdConfig, JobStatus
1010
from graphdatascience.arrow_client.v2.data_mapper_utils import deserialize_single
1111
from graphdatascience.query_runner.progress.progress_bar import TqdmProgressBar
12+
from graphdatascience.query_runner.termination_flag import TerminationFlag
1213

1314
JOB_STATUS_ENDPOINT = "v2/jobs.status"
1415
RESULTS_SUMMARY_ENDPOINT = "v2/results.summary"
@@ -33,11 +34,22 @@ def run_job(client: AuthenticatedArrowClient, endpoint: str, config: dict[str, A
3334
single = deserialize_single(res)
3435
return JobIdConfig(**single).job_id
3536

36-
def wait_for_job(self, client: AuthenticatedArrowClient, job_id: str, show_progress: bool) -> None:
37+
def wait_for_job(
38+
self,
39+
client: AuthenticatedArrowClient,
40+
job_id: str,
41+
show_progress: bool,
42+
termination_flag: Optional[TerminationFlag] = None,
43+
) -> None:
3744
progress_bar: Optional[TqdmProgressBar] = None
3845

39-
for attempt in Retrying(retry=retry_always, wait=wait_exponential(min=0.1, max=5)):
46+
if termination_flag is None:
47+
termination_flag = TerminationFlag.create()
48+
49+
for attempt in Retrying(retry=retry_if_result(lambda _: True), wait=wait_exponential(min=0.1, max=5)):
4050
with attempt:
51+
termination_flag.assert_running()
52+
4153
arrow_res = client.do_action_with_retry(JOB_STATUS_ENDPOINT, JobIdConfig(jobId=job_id).dump_camel())
4254
job_status = JobStatus(**deserialize_single(arrow_res))
4355

graphdatascience/arrow_client/v2/remote_write_back_client.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from graphdatascience.call_parameters import CallParameters
99
from graphdatascience.procedure_surface.api.base_result import BaseResult
1010
from graphdatascience.query_runner.protocol.write_protocols import WriteProtocol
11-
from graphdatascience.query_runner.termination_flag import TerminationFlagNoop
11+
from graphdatascience.query_runner.termination_flag import TerminationFlag
1212
from graphdatascience.session.dbms.protocol_resolver import ProtocolVersionResolver
1313

1414

@@ -53,7 +53,7 @@ def write(
5353
write_back_params,
5454
None,
5555
log_progress=log_progress,
56-
terminationFlag=TerminationFlagNoop(),
56+
terminationFlag=TerminationFlag.create(),
5757
).squeeze()
5858
write_millis = int((time.time() - start_time) * 1000)
5959

graphdatascience/tests/unit/arrow_client/V2/test_job_client.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
from io import StringIO
22

3+
import pytest
34
from pytest_mock import MockerFixture
45

56
from graphdatascience.arrow_client.v2.api_types import UNKNOWN_PROGRESS, JobIdConfig, JobStatus
67
from graphdatascience.arrow_client.v2.job_client import JobClient
8+
from graphdatascience.query_runner.termination_flag import TerminationFlag
79
from graphdatascience.tests.unit.arrow_client.arrow_test_utils import ArrowTestResult
810

911

@@ -126,6 +128,22 @@ def test_wait_for_job_waits_for_aborted(mocker: MockerFixture) -> None:
126128
assert mock_client.do_action_with_retry.call_count == 2
127129

128130

131+
def test_wait_for_job_stops_on_interrupt(mocker: MockerFixture) -> None:
132+
mock_client = mocker.Mock()
133+
job_id = "test-job-waiting"
134+
mock_client.do_action_with_retry = mocker.Mock()
135+
136+
termination_flag = TerminationFlag.create()
137+
termination_flag.set()
138+
139+
with pytest.raises(
140+
RuntimeError, match="Closing client connection. Note, the query will be continued on the server-side"
141+
):
142+
JobClient().wait_for_job(mock_client, job_id, show_progress=False, termination_flag=termination_flag)
143+
144+
assert mock_client.do_action_with_retry.call_count == 0
145+
146+
129147
def test_wait_for_job_progress_bar_quantive(mocker: MockerFixture) -> None:
130148
mock_client = mocker.Mock()
131149
job_id = "test-job-progress"

0 commit comments

Comments
 (0)