1+ #!/usr/bin/env python
2+ # -*- coding: utf-8; -*-
3+
4+ # Copyright (c) 2024 Oracle and/or its affiliates.
5+ # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6+
7+ import logging
8+ import sys
9+ import time
10+ from typing import Callable
11+
12+ import oci
13+ from oci import Signer
14+ from tqdm .auto import tqdm
15+ from ads .common .oci_datascience import OCIDataScienceMixin
16+
17+ logger = logging .getLogger (__name__ )
18+
19+ WORK_REQUEST_STOP_STATE = ("SUCCEEDED" , "FAILED" , "CANCELED" )
20+ DEFAULT_WAIT_TIME = 1200
21+ DEFAULT_POLL_INTERVAL = 10
22+ WORK_REQUEST_PERCENTAGE = 100
23+ # default tqdm progress bar format:
24+ # {l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, ' '{rate_fmt}{postfix}]
25+ # customize the bar format to remove the {n_fmt}/{total_fmt} from the right side
26+ DEFAULT_BAR_FORMAT = '{l_bar}{bar}| [{elapsed}<{remaining}, ' '{rate_fmt}{postfix}]'
27+
28+
29+ class DataScienceWorkRequest (OCIDataScienceMixin ):
30+ """Class for monitoring OCI WorkRequest and representing on tqdm progress bar. This class inherits
31+ `OCIDataScienceMixin` so as to call its `client` attribute to interact with OCI backend.
32+ """
33+
34+ def __init__ (
35+ self ,
36+ id : str ,
37+ description : str = "Processing" ,
38+ config : dict = None ,
39+ signer : Signer = None ,
40+ client_kwargs : dict = None ,
41+ ** kwargs
42+ ) -> None :
43+ """Initializes ADSWorkRequest object.
44+
45+ Parameters
46+ ----------
47+ id: str
48+ Work Request OCID.
49+ description: str
50+ Progress bar initial step description (Defaults to `Processing`).
51+ config : dict, optional
52+ OCI API key config dictionary to initialize
53+ oci.data_science.DataScienceClient (Defaults to None).
54+ signer : oci.signer.Signer, optional
55+ OCI authentication signer to initialize
56+ oci.data_science.DataScienceClient (Defaults to None).
57+ client_kwargs : dict, optional
58+ Additional client keyword arguments to initialize
59+ oci.data_science.DataScienceClient (Defaults to None).
60+ kwargs:
61+ Additional keyword arguments to initialize
62+ oci.data_science.DataScienceClient.
63+ """
64+ self .id = id
65+ self ._description = description
66+ self ._percentage = 0
67+ self ._status = None
68+ super ().__init__ (config , signer , client_kwargs , ** kwargs )
69+
70+
71+ def _sync (self ):
72+ """Fetches the latest work request information to ADSWorkRequest object."""
73+ work_request = self .client .get_work_request (self .id ).data
74+ work_request_logs = self .client .list_work_request_logs (
75+ self .id
76+ ).data
77+
78+ self ._percentage = work_request .percent_complete
79+ self ._status = work_request .status
80+ self ._description = work_request_logs [- 1 ].message if work_request_logs else "Processing"
81+
82+ def watch (
83+ self ,
84+ progress_callback : Callable ,
85+ max_wait_time : int = DEFAULT_WAIT_TIME ,
86+ poll_interval : int = DEFAULT_POLL_INTERVAL ,
87+ ):
88+ """Updates the progress bar with realtime message and percentage until the process is completed.
89+
90+ Parameters
91+ ----------
92+ progress_callback: Callable
93+ Progress bar callback function.
94+ It must accept `(percent_change, description)` where `percent_change` is the
95+ work request percent complete and `description` is the latest work request log message.
96+ max_wait_time: int
97+ Maximum amount of time to wait in seconds (Defaults to 1200).
98+ Negative implies infinite wait time.
99+ poll_interval: int
100+ Poll interval in seconds (Defaults to 10).
101+
102+ Returns
103+ -------
104+ None
105+ """
106+ previous_percent_complete = 0
107+
108+ start_time = time .time ()
109+ while self ._percentage < 100 :
110+
111+ seconds_since = time .time () - start_time
112+ if max_wait_time > 0 and seconds_since >= max_wait_time :
113+ logger .error (f"Exceeded max wait time of { max_wait_time } seconds." )
114+ return
115+
116+ time .sleep (poll_interval )
117+
118+ try :
119+ self ._sync ()
120+ except Exception as ex :
121+ logger .warn (ex )
122+ continue
123+
124+ percent_change = self ._percentage - previous_percent_complete
125+ previous_percent_complete = self ._percentage
126+ progress_callback (
127+ percent_change = percent_change ,
128+ description = self ._description
129+ )
130+
131+ if self ._status in WORK_REQUEST_STOP_STATE :
132+ if self ._status != oci .work_requests .models .WorkRequest .STATUS_SUCCEEDED :
133+ if self ._description :
134+ raise Exception (self ._description )
135+ else :
136+ raise Exception (
137+ "Error occurred in attempt to perform the operation. "
138+ "Check the service logs to get more details. "
139+ f"Work request id: { self .id } ."
140+ )
141+ else :
142+ break
143+
144+ progress_callback (percent_change = 0 , description = "Done" )
145+
146+ def wait_work_request (
147+ self ,
148+ progress_bar_description : str = "Processing" ,
149+ max_wait_time : int = DEFAULT_WAIT_TIME ,
150+ poll_interval : int = DEFAULT_POLL_INTERVAL
151+ ):
152+ """Waits for the work request progress bar to be completed.
153+
154+ Parameters
155+ ----------
156+ progress_bar_description: str
157+ Progress bar initial step description (Defaults to `Processing`).
158+ max_wait_time: int
159+ Maximum amount of time to wait in seconds (Defaults to 1200).
160+ Negative implies infinite wait time.
161+ poll_interval: int
162+ Poll interval in seconds (Defaults to 10).
163+
164+ Returns
165+ -------
166+ None
167+ """
168+
169+ with tqdm (
170+ total = WORK_REQUEST_PERCENTAGE ,
171+ leave = False ,
172+ mininterval = 0 ,
173+ file = sys .stdout ,
174+ desc = progress_bar_description ,
175+ bar_format = DEFAULT_BAR_FORMAT
176+ ) as pbar :
177+
178+ def progress_callback (percent_change , description ):
179+ if percent_change != 0 :
180+ pbar .update (percent_change )
181+ if description :
182+ pbar .set_description (description )
183+
184+ self .watch (
185+ progress_callback = progress_callback ,
186+ max_wait_time = max_wait_time ,
187+ poll_interval = poll_interval
188+ )
189+
0 commit comments