11#!/usr/bin/env python
2- # -*- coding: utf-8; -*-
3-
4- # Copyright (c) 2024 Oracle and/or its affiliates.
2+ # Copyright (c) 2024, 2025 Oracle and/or its affiliates.
53# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
64
75import logging
1210import oci
1311from oci import Signer
1412from tqdm .auto import tqdm
13+
1514from ads .common .oci_datascience import OCIDataScienceMixin
1615
1716logger = logging .getLogger (__name__ )
2019DEFAULT_WAIT_TIME = 1200
2120DEFAULT_POLL_INTERVAL = 10
2221WORK_REQUEST_PERCENTAGE = 100
23- # default tqdm progress bar format:
22+ # default tqdm progress bar format:
2423# {l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, ' '{rate_fmt}{postfix}]
2524# customize the bar format to remove the {n_fmt}/{total_fmt} from the right side
26- DEFAULT_BAR_FORMAT = ' {l_bar}{bar}| [{elapsed}<{remaining}, ' ' {rate_fmt}{postfix}]'
25+ DEFAULT_BAR_FORMAT = " {l_bar}{bar}| [{elapsed}<{remaining}, " " {rate_fmt}{postfix}]"
2726
2827
2928class DataScienceWorkRequest (OCIDataScienceMixin ):
@@ -32,13 +31,13 @@ class DataScienceWorkRequest(OCIDataScienceMixin):
3231 """
3332
3433 def __init__ (
35- self ,
36- id : str ,
34+ self ,
35+ id : str ,
3736 description : str = "Processing" ,
38- config : dict = None ,
39- signer : Signer = None ,
37+ config : dict = None ,
38+ signer : Signer = None ,
4039 client_kwargs : dict = None ,
41- ** kwargs
40+ ** kwargs ,
4241 ) -> None :
4342 """Initializes ADSWorkRequest object.
4443
@@ -49,43 +48,43 @@ def __init__(
4948 description: str
5049 Progress bar initial step description (Defaults to `Processing`).
5150 config : dict, optional
52- OCI API key config dictionary to initialize
51+ OCI API key config dictionary to initialize
5352 oci.data_science.DataScienceClient (Defaults to None).
5453 signer : oci.signer.Signer, optional
55- OCI authentication signer to initialize
54+ OCI authentication signer to initialize
5655 oci.data_science.DataScienceClient (Defaults to None).
5756 client_kwargs : dict, optional
58- Additional client keyword arguments to initialize
57+ Additional client keyword arguments to initialize
5958 oci.data_science.DataScienceClient (Defaults to None).
6059 kwargs:
61- Additional keyword arguments to initialize
60+ Additional keyword arguments to initialize
6261 oci.data_science.DataScienceClient.
6362 """
6463 self .id = id
6564 self ._description = description
6665 self ._percentage = 0
6766 self ._status = None
68- _error_message = None
67+ self . _error_message = ""
6968 super ().__init__ (config , signer , client_kwargs , ** kwargs )
70-
7169
7270 def _sync (self ):
7371 """Fetches the latest work request information to ADSWorkRequest object."""
7472 work_request = self .client .get_work_request (self .id ).data
75- work_request_logs = self .client .list_work_request_logs (
76- self .id
77- ).data
73+ work_request_logs = self .client .list_work_request_logs (self .id ).data
7874
79- self ._percentage = work_request .percent_complete
75+ self ._percentage = work_request .percent_complete
8076 self ._status = work_request .status
81- self ._description = work_request_logs [- 1 ].message if work_request_logs else "Processing"
82- if work_request .status == 'FAILED' : self ._error_message = self .client .list_work_request_errors
77+ self ._description = (
78+ work_request_logs [- 1 ].message if work_request_logs else "Processing"
79+ )
80+ if work_request .status == "FAILED" :
81+ self ._error_message = self .client .list_work_request_errors (self .id ).data
8382
8483 def watch (
85- self ,
84+ self ,
8685 progress_callback : Callable ,
87- max_wait_time : int = DEFAULT_WAIT_TIME ,
88- poll_interval : int = DEFAULT_POLL_INTERVAL ,
86+ max_wait_time : int = DEFAULT_WAIT_TIME ,
87+ poll_interval : int = DEFAULT_POLL_INTERVAL ,
8988 ):
9089 """Updates the progress bar with realtime message and percentage until the process is completed.
9190
@@ -94,10 +93,10 @@ def watch(
9493 progress_callback: Callable
9594 Progress bar callback function.
9695 It must accept `(percent_change, description)` where `percent_change` is the
97- work request percent complete and `description` is the latest work request log message.
96+ work request percent complete and `description` is the latest work request log message.
9897 max_wait_time: int
9998 Maximum amount of time to wait in seconds (Defaults to 1200).
100- Negative implies infinite wait time.
99+ Negative implies infinite wait time.
101100 poll_interval: int
102101 Poll interval in seconds (Defaults to 10).
103102
@@ -109,7 +108,6 @@ def watch(
109108
110109 start_time = time .time ()
111110 while self ._percentage < 100 :
112-
113111 seconds_since = time .time () - start_time
114112 if max_wait_time > 0 and seconds_since >= max_wait_time :
115113 logger .error (f"Exceeded max wait time of { max_wait_time } seconds." )
@@ -126,12 +124,14 @@ def watch(
126124 percent_change = self ._percentage - previous_percent_complete
127125 previous_percent_complete = self ._percentage
128126 progress_callback (
129- percent_change = percent_change ,
130- description = self ._description
127+ percent_change = percent_change , description = self ._description
131128 )
132129
133130 if self ._status in WORK_REQUEST_STOP_STATE :
134- if self ._status != oci .work_requests .models .WorkRequest .STATUS_SUCCEEDED :
131+ if (
132+ self ._status
133+ != oci .work_requests .models .WorkRequest .STATUS_SUCCEEDED
134+ ):
135135 if self ._description :
136136 raise Exception (self ._description )
137137 else :
@@ -147,12 +147,12 @@ def watch(
147147
148148 def wait_work_request (
149149 self ,
150- progress_bar_description : str = "Processing" ,
151- max_wait_time : int = DEFAULT_WAIT_TIME ,
152- poll_interval : int = DEFAULT_POLL_INTERVAL
150+ progress_bar_description : str = "Processing" ,
151+ max_wait_time : int = DEFAULT_WAIT_TIME ,
152+ poll_interval : int = DEFAULT_POLL_INTERVAL ,
153153 ):
154154 """Waits for the work request progress bar to be completed.
155-
155+
156156 Parameters
157157 ----------
158158 progress_bar_description: str
@@ -162,7 +162,7 @@ def wait_work_request(
162162 Negative implies infinite wait time.
163163 poll_interval: int
164164 Poll interval in seconds (Defaults to 10).
165-
165+
166166 Returns
167167 -------
168168 None
@@ -174,7 +174,7 @@ def wait_work_request(
174174 mininterval = 0 ,
175175 file = sys .stdout ,
176176 desc = progress_bar_description ,
177- bar_format = DEFAULT_BAR_FORMAT
177+ bar_format = DEFAULT_BAR_FORMAT ,
178178 ) as pbar :
179179
180180 def progress_callback (percent_change , description ):
@@ -186,6 +186,5 @@ def progress_callback(percent_change, description):
186186 self .watch (
187187 progress_callback = progress_callback ,
188188 max_wait_time = max_wait_time ,
189- poll_interval = poll_interval
189+ poll_interval = poll_interval ,
190190 )
191-
0 commit comments