1111import logging
1212import os
1313import re
14+ import sys
1415import time
1516import traceback
1617from datetime import date , datetime
1718from typing import Callable , Optional , Union
1819from enum import Enum
1920
2021import oci
22+ import tqdm
2123import yaml
2224from ads .common import auth
2325from ads .common .decorator .utils import class_or_instance_method
@@ -1038,3 +1040,95 @@ def from_name(cls, name: str, compartment_id: Optional[str] = None):
10381040 if not res :
10391041 raise OCIModelNotExists ()
10401042 return cls .from_oci_model (res [0 ])
1043+
1044+
1045+ class ADSWorkRequest (OCIClientMixin ):
1046+
1047+ def __init__ (self , id : str , description : str = "Processing" ):
1048+ self .id = id
1049+ self ._description = description
1050+ self ._percentage = 0
1051+ self ._status = None
1052+
1053+ def _sync (self ):
1054+ try :
1055+ work_request = self .client .get_work_request (self .id ).data
1056+ work_request_logs = self .client .list_work_request_logs (
1057+ self .id
1058+ ).data
1059+
1060+ self ._percentage = work_request .percent_complete
1061+ self ._status = work_request .status
1062+ self ._description = work_request_logs [:- 1 ]
1063+ except Exception as ex :
1064+ logger .warn (ex )
1065+
1066+ def watch (
1067+ self ,
1068+ progress_callback : Callable ,
1069+ max_wait_time : int ,
1070+ poll_interval : int ,
1071+ ):
1072+ previous_percent_complete = 0
1073+ previous_log = None
1074+
1075+ start_time = time .time ()
1076+ while self ._percentage < 100 :
1077+
1078+ seconds_since = time .time () - start_time
1079+ if max_wait_time > 0 and seconds_since >= max_wait_time :
1080+ logger .error (f"Max wait time ({ max_wait_time } seconds) exceeded." )
1081+ return
1082+
1083+ time .sleep (poll_interval )
1084+ self ._sync ()
1085+ percent_change = self ._percentage - previous_percent_complete
1086+ previous_percent_complete = self ._percentage
1087+ description = self ._description if previous_log != self ._description else ""
1088+ progress_callback (
1089+ percent_change = percent_change ,
1090+ description = description
1091+ )
1092+ previous_log = self ._description
1093+
1094+ if self ._status in WORK_REQUEST_STOP_STATE :
1095+ if self ._status != oci .work_requests .models .WorkRequest .STATUS_SUCCEEDED :
1096+ if self ._description :
1097+ raise Exception (self ._description )
1098+ else :
1099+ raise Exception (
1100+ "Error occurred in attempt to perform the operation. "
1101+ "Check the service logs to get more details. "
1102+ )
1103+ else :
1104+ break
1105+
1106+ progress_callback (percent_change = 0 , description = "Done" )
1107+
1108+
1109+ def wait_work_request (
1110+ id : str ,
1111+ desc : str ,
1112+ max_wait_time : int = DEFAULT_WAIT_TIME ,
1113+ poll_interval : int = DEFAULT_POLL_INTERVAL
1114+ ):
1115+ ads_work_request = ADSWorkRequest (id )
1116+
1117+ with tqdm (
1118+ leave = False ,
1119+ file = sys .stdout ,
1120+ desc = desc ,
1121+ ) as pbar :
1122+
1123+ def progress_callback (percent_change , description ):
1124+ if percent_change != 0 :
1125+ pbar .update (percent_change )
1126+ if description :
1127+ pbar .set_description (description )
1128+
1129+ ads_work_request .watch (
1130+ progress_callback ,
1131+ max_wait_time = max_wait_time ,
1132+ poll_interval = poll_interval
1133+ )
1134+
0 commit comments