Skip to content

Commit cb7b12f

Browse files
authored
Merge pull request #974 from FlorentinD/v2-terminationflag-GDSA-208
Allow interruption during wait for job
2 parents a38d88a + a6f18ce commit cb7b12f

File tree

3 files changed

+58
-25
lines changed

3 files changed

+58
-25
lines changed

graphdatascience/arrow_client/v2/job_client.py

Lines changed: 36 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,13 @@
33

44
from pandas import ArrowDtype, DataFrame
55
from pyarrow._flight import Ticket
6+
from tenacity import Retrying, retry_if_result, wait_exponential
67

78
from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient
89
from graphdatascience.arrow_client.v2.api_types import JobIdConfig, JobStatus
910
from graphdatascience.arrow_client.v2.data_mapper_utils import deserialize_single
1011
from graphdatascience.query_runner.progress.progress_bar import TqdmProgressBar
12+
from graphdatascience.query_runner.termination_flag import TerminationFlag
1113

1214
JOB_STATUS_ENDPOINT = "v2/jobs.status"
1315
RESULTS_SUMMARY_ENDPOINT = "v2/results.summary"
@@ -32,28 +34,41 @@ def run_job(client: AuthenticatedArrowClient, endpoint: str, config: dict[str, A
3234
single = deserialize_single(res)
3335
return JobIdConfig(**single).job_id
3436

35-
def wait_for_job(self, client: AuthenticatedArrowClient, job_id: str, show_progress: bool) -> None:
37+
def wait_for_job(
38+
self,
39+
client: AuthenticatedArrowClient,
40+
job_id: str,
41+
show_progress: bool,
42+
termination_flag: Optional[TerminationFlag] = None,
43+
) -> None:
3644
progress_bar: Optional[TqdmProgressBar] = None
37-
while True:
38-
arrow_res = client.do_action_with_retry(JOB_STATUS_ENDPOINT, JobIdConfig(jobId=job_id).dump_camel())
39-
job_status = JobStatus(**deserialize_single(arrow_res))
40-
41-
if job_status.succeeded() or job_status.aborted():
42-
if progress_bar:
43-
progress_bar.finish(success=job_status.succeeded())
44-
break
45-
46-
if show_progress:
47-
if progress_bar is None:
48-
base_task = job_status.base_task()
49-
if base_task:
50-
progress_bar = TqdmProgressBar(
51-
task_name=base_task,
52-
relative_progress=job_status.progress_percent(),
53-
bar_options=self._progress_bar_options,
54-
)
55-
if progress_bar:
56-
progress_bar.update(job_status.status, job_status.progress_percent(), job_status.sub_tasks())
45+
46+
if termination_flag is None:
47+
termination_flag = TerminationFlag.create()
48+
49+
for attempt in Retrying(retry=retry_if_result(lambda _: True), wait=wait_exponential(min=0.1, max=5)):
50+
with attempt:
51+
termination_flag.assert_running()
52+
53+
arrow_res = client.do_action_with_retry(JOB_STATUS_ENDPOINT, JobIdConfig(jobId=job_id).dump_camel())
54+
job_status = JobStatus(**deserialize_single(arrow_res))
55+
56+
if job_status.succeeded() or job_status.aborted():
57+
if progress_bar:
58+
progress_bar.finish(success=job_status.succeeded())
59+
return
60+
61+
if show_progress:
62+
if progress_bar is None:
63+
base_task = job_status.base_task()
64+
if base_task:
65+
progress_bar = TqdmProgressBar(
66+
task_name=base_task,
67+
relative_progress=job_status.progress_percent(),
68+
bar_options=self._progress_bar_options,
69+
)
70+
if progress_bar:
71+
progress_bar.update(job_status.status, job_status.progress_percent(), job_status.sub_tasks())
5772

5873
@staticmethod
5974
def get_summary(client: AuthenticatedArrowClient, job_id: str) -> dict[str, Any]:

graphdatascience/arrow_client/v2/remote_write_back_client.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from graphdatascience.call_parameters import CallParameters
99
from graphdatascience.procedure_surface.api.base_result import BaseResult
1010
from graphdatascience.query_runner.protocol.write_protocols import WriteProtocol
11-
from graphdatascience.query_runner.termination_flag import TerminationFlagNoop
11+
from graphdatascience.query_runner.termination_flag import TerminationFlag
1212
from graphdatascience.session.dbms.protocol_resolver import ProtocolVersionResolver
1313

1414

@@ -53,7 +53,7 @@ def write(
5353
write_back_params,
5454
None,
5555
log_progress=log_progress,
56-
terminationFlag=TerminationFlagNoop(),
56+
terminationFlag=TerminationFlag.create(),
5757
).squeeze()
5858
write_millis = int((time.time() - start_time) * 1000)
5959

graphdatascience/tests/unit/arrow_client/V2/test_job_client.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
from io import StringIO
22

3+
import pytest
34
from pytest_mock import MockerFixture
45

56
from graphdatascience.arrow_client.v2.api_types import UNKNOWN_PROGRESS, JobIdConfig, JobStatus
67
from graphdatascience.arrow_client.v2.job_client import JobClient
8+
from graphdatascience.query_runner.termination_flag import TerminationFlag
79
from graphdatascience.tests.unit.arrow_client.arrow_test_utils import ArrowTestResult
810

911

@@ -126,6 +128,22 @@ def test_wait_for_job_waits_for_aborted(mocker: MockerFixture) -> None:
126128
assert mock_client.do_action_with_retry.call_count == 2
127129

128130

131+
def test_wait_for_job_stops_on_interrupt(mocker: MockerFixture) -> None:
132+
mock_client = mocker.Mock()
133+
job_id = "test-job-waiting"
134+
mock_client.do_action_with_retry = mocker.Mock()
135+
136+
termination_flag = TerminationFlag.create()
137+
termination_flag.set()
138+
139+
with pytest.raises(
140+
RuntimeError, match="Closing client connection. Note, the query will be continued on the server-side"
141+
):
142+
JobClient().wait_for_job(mock_client, job_id, show_progress=False, termination_flag=termination_flag)
143+
144+
assert mock_client.do_action_with_retry.call_count == 0
145+
146+
129147
def test_wait_for_job_progress_bar_quantive(mocker: MockerFixture) -> None:
130148
mock_client = mocker.Mock()
131149
job_id = "test-job-progress"
@@ -173,8 +191,8 @@ def test_wait_for_job_progress_bar_qualitative(mocker: MockerFixture) -> None:
173191

174192
progress_output = pbarOutputStream.getvalue().split("\r")
175193
assert "Algo [elapsed: 00:00 ]" in progress_output
176-
assert "Algo [elapsed: 00:00 , status: RUNNING, task: Halfway there]" in progress_output
177-
assert any("Algo [elapsed: 00:00 , status: FINISHED]" in line for line in progress_output)
194+
assert "Algo [elapsed: 00:01 , status: RUNNING, task: Halfway there]" in progress_output
195+
assert any("Algo [elapsed: 00:03 , status: FINISHED]" in line for line in progress_output)
178196

179197

180198
def test_get_summary(mocker: MockerFixture) -> None:

0 commit comments

Comments
 (0)