|
3 | 3 |
|
4 | 4 | from pandas import ArrowDtype, DataFrame |
5 | 5 | from pyarrow._flight import Ticket |
| 6 | +from tenacity import Retrying, retry_always, wait_exponential |
6 | 7 |
|
7 | 8 | from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient |
8 | 9 | from graphdatascience.arrow_client.v2.api_types import JobIdConfig, JobStatus |
@@ -34,26 +35,28 @@ def run_job(client: AuthenticatedArrowClient, endpoint: str, config: dict[str, A |
34 | 35 |
|
35 | 36 | def wait_for_job(self, client: AuthenticatedArrowClient, job_id: str, show_progress: bool) -> None: |
36 | 37 | progress_bar: Optional[TqdmProgressBar] = None |
37 | | - while True: |
38 | | - arrow_res = client.do_action_with_retry(JOB_STATUS_ENDPOINT, JobIdConfig(jobId=job_id).dump_camel()) |
39 | | - job_status = JobStatus(**deserialize_single(arrow_res)) |
40 | | - |
41 | | - if job_status.succeeded() or job_status.aborted(): |
42 | | - if progress_bar: |
43 | | - progress_bar.finish(success=job_status.succeeded()) |
44 | | - break |
45 | | - |
46 | | - if show_progress: |
47 | | - if progress_bar is None: |
48 | | - base_task = job_status.base_task() |
49 | | - if base_task: |
50 | | - progress_bar = TqdmProgressBar( |
51 | | - task_name=base_task, |
52 | | - relative_progress=job_status.progress_percent(), |
53 | | - bar_options=self._progress_bar_options, |
54 | | - ) |
55 | | - if progress_bar: |
56 | | - progress_bar.update(job_status.status, job_status.progress_percent(), job_status.sub_tasks()) |
| 38 | + |
| 39 | + for attempt in Retrying(retry=retry_always, wait=wait_exponential(min=0.1, max=5)): |
| 40 | + with attempt: |
| 41 | + arrow_res = client.do_action_with_retry(JOB_STATUS_ENDPOINT, JobIdConfig(jobId=job_id).dump_camel()) |
| 42 | + job_status = JobStatus(**deserialize_single(arrow_res)) |
| 43 | + |
| 44 | + if job_status.succeeded() or job_status.aborted(): |
| 45 | + if progress_bar: |
| 46 | + progress_bar.finish(success=job_status.succeeded()) |
| 47 | + return |
| 48 | + |
| 49 | + if show_progress: |
| 50 | + if progress_bar is None: |
| 51 | + base_task = job_status.base_task() |
| 52 | + if base_task: |
| 53 | + progress_bar = TqdmProgressBar( |
| 54 | + task_name=base_task, |
| 55 | + relative_progress=job_status.progress_percent(), |
| 56 | + bar_options=self._progress_bar_options, |
| 57 | + ) |
| 58 | + if progress_bar: |
| 59 | + progress_bar.update(job_status.status, job_status.progress_percent(), job_status.sub_tasks()) |
57 | 60 |
|
58 | 61 | @staticmethod |
59 | 62 | def get_summary(client: AuthenticatedArrowClient, job_id: str) -> dict[str, Any]: |
|
0 commit comments