33
44from pandas import ArrowDtype , DataFrame
55from pyarrow ._flight import Ticket
6+ from tenacity import Retrying , retry_if_result , wait_exponential
67
78from graphdatascience .arrow_client .authenticated_flight_client import AuthenticatedArrowClient
89from graphdatascience .arrow_client .v2 .api_types import JobIdConfig , JobStatus
910from graphdatascience .arrow_client .v2 .data_mapper_utils import deserialize_single
1011from graphdatascience .query_runner .progress .progress_bar import TqdmProgressBar
12+ from graphdatascience .query_runner .termination_flag import TerminationFlag
1113
1214JOB_STATUS_ENDPOINT = "v2/jobs.status"
1315RESULTS_SUMMARY_ENDPOINT = "v2/results.summary"
@@ -32,28 +34,41 @@ def run_job(client: AuthenticatedArrowClient, endpoint: str, config: dict[str, A
3234 single = deserialize_single (res )
3335 return JobIdConfig (** single ).job_id
3436
35- def wait_for_job (self , client : AuthenticatedArrowClient , job_id : str , show_progress : bool ) -> None :
37+ def wait_for_job (
38+ self ,
39+ client : AuthenticatedArrowClient ,
40+ job_id : str ,
41+ show_progress : bool ,
42+ termination_flag : Optional [TerminationFlag ] = None ,
43+ ) -> None :
3644 progress_bar : Optional [TqdmProgressBar ] = None
37- while True :
38- arrow_res = client .do_action_with_retry (JOB_STATUS_ENDPOINT , JobIdConfig (jobId = job_id ).dump_camel ())
39- job_status = JobStatus (** deserialize_single (arrow_res ))
40-
41- if job_status .succeeded () or job_status .aborted ():
42- if progress_bar :
43- progress_bar .finish (success = job_status .succeeded ())
44- break
45-
46- if show_progress :
47- if progress_bar is None :
48- base_task = job_status .base_task ()
49- if base_task :
50- progress_bar = TqdmProgressBar (
51- task_name = base_task ,
52- relative_progress = job_status .progress_percent (),
53- bar_options = self ._progress_bar_options ,
54- )
55- if progress_bar :
56- progress_bar .update (job_status .status , job_status .progress_percent (), job_status .sub_tasks ())
45+
46+ if termination_flag is None :
47+ termination_flag = TerminationFlag .create ()
48+
49+ for attempt in Retrying (retry = retry_if_result (lambda _ : True ), wait = wait_exponential (min = 0.1 , max = 5 )):
50+ with attempt :
51+ termination_flag .assert_running ()
52+
53+ arrow_res = client .do_action_with_retry (JOB_STATUS_ENDPOINT , JobIdConfig (jobId = job_id ).dump_camel ())
54+ job_status = JobStatus (** deserialize_single (arrow_res ))
55+
56+ if job_status .succeeded () or job_status .aborted ():
57+ if progress_bar :
58+ progress_bar .finish (success = job_status .succeeded ())
59+ return
60+
61+ if show_progress :
62+ if progress_bar is None :
63+ base_task = job_status .base_task ()
64+ if base_task :
65+ progress_bar = TqdmProgressBar (
66+ task_name = base_task ,
67+ relative_progress = job_status .progress_percent (),
68+ bar_options = self ._progress_bar_options ,
69+ )
70+ if progress_bar :
71+ progress_bar .update (job_status .status , job_status .progress_percent (), job_status .sub_tasks ())
5772
5873 @staticmethod
5974 def get_summary (client : AuthenticatedArrowClient , job_id : str ) -> dict [str , Any ]:
0 commit comments