Skip to content

Commit 70c1fd5

Browse files
committed
Merge remote-tracking branch 'origin/main' into feature/feature-store-marketplace-operator
2 parents cbea610 + d1b0ede commit 70c1fd5

File tree

24 files changed

+1156
-579
lines changed

24 files changed

+1156
-579
lines changed

ads/common/oci_datascience.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,28 @@
11
#!/usr/bin/env python
22
# -*- coding: utf-8; -*-
33

4-
# Copyright (c) 2021, 2023 Oracle and/or its affiliates.
4+
# Copyright (c) 2021, 2024 Oracle and/or its affiliates.
55
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
66

7+
import os
8+
79
import oci.data_science
8-
from ads.common.oci_mixin import OCIModelMixin
10+
911
from ads.common.decorator.utils import class_or_instance_method
12+
from ads.common.oci_mixin import OCIModelMixin
13+
14+
ENV_VAR_OCI_ODSC_SERVICE_ENDPOINT = "OCI_ODSC_SERVICE_ENDPOINT"
1015

1116

1217
class OCIDataScienceMixin(OCIModelMixin):
1318
@class_or_instance_method
1419
def init_client(cls, **kwargs) -> oci.data_science.DataScienceClient:
20+
client_kwargs = kwargs.get("client_kwargs", {})
21+
if os.environ.get(ENV_VAR_OCI_ODSC_SERVICE_ENDPOINT):
22+
client_kwargs.update(
23+
dict(service_endpoint=os.environ.get(ENV_VAR_OCI_ODSC_SERVICE_ENDPOINT))
24+
)
25+
kwargs.update(client_kwargs)
1526
return cls._init_client(client=oci.data_science.DataScienceClient, **kwargs)
1627

1728
@property

ads/common/oci_mixin.py

Lines changed: 8 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#!/usr/bin/env python
22
# -*- coding: utf-8; -*-
33

4-
# Copyright (c) 2021, 2023 Oracle and/or its affiliates.
4+
# Copyright (c) 2021, 2024 Oracle and/or its affiliates.
55
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
66

77
"""Contains Mixins for integrating OCI data models
@@ -11,22 +11,22 @@
1111
import logging
1212
import os
1313
import re
14-
import time
1514
import traceback
1615
from datetime import date, datetime
17-
from typing import Callable, Optional, Union
1816
from enum import Enum
17+
from typing import Callable, Optional, Union
1918

2019
import oci
2120
import yaml
22-
from ads.common import auth
23-
from ads.common.decorator.utils import class_or_instance_method
24-
from ads.common.utils import camel_to_snake, get_progress_bar
25-
from ads.config import COMPARTMENT_OCID
2621
from dateutil import tz
2722
from dateutil.parser import parse
2823
from oci._vendor import six
2924

25+
from ads.common import auth
26+
from ads.common.decorator.utils import class_or_instance_method
27+
from ads.common.utils import camel_to_snake
28+
from ads.config import COMPARTMENT_OCID
29+
3030
logger = logging.getLogger(__name__)
3131

3232
LIFECYCLE_STOP_STATE = ("SUCCEEDED", "FAILED", "CANCELED", "DELETED")
@@ -274,7 +274,7 @@ def deserialize(cls, data, to_cls):
274274
else:
275275
return cls.__deserialize_model(data, to_cls)
276276

277-
@classmethod
277+
@class_or_instance_method
278278
def __deserialize_model(cls, data, to_cls):
279279
"""De-serializes list or dict to model."""
280280
if isinstance(data, to_cls):
@@ -936,88 +936,6 @@ def get_work_request_response(
936936
)
937937
return work_request_response
938938

939-
def wait_for_progress(
940-
self,
941-
work_request_id: str,
942-
max_wait_time: int = DEFAULT_WAIT_TIME,
943-
poll_interval: int = DEFAULT_POLL_INTERVAL,
944-
):
945-
"""Waits for the work request progress bar to be completed.
946-
947-
Parameters
948-
----------
949-
work_request_id: str
950-
Work Request OCID.
951-
max_wait_time: int
952-
Maximum amount of time to wait in seconds (Defaults to 1200).
953-
Negative implies infinite wait time.
954-
poll_interval: int
955-
Poll interval in seconds (Defaults to 10).
956-
957-
Returns
958-
-------
959-
None
960-
"""
961-
work_request_logs = []
962-
963-
i = 0
964-
start_time = time.time()
965-
with get_progress_bar(WORK_REQUEST_PERCENTAGE) as progress:
966-
seconds_since = time.time() - start_time
967-
exceed_max_time = max_wait_time > 0 and seconds_since >= max_wait_time
968-
if exceed_max_time:
969-
logger.error(f"Max wait time ({max_wait_time} seconds) exceeded.")
970-
previous_percent_complete = 0
971-
while not exceed_max_time and (
972-
not work_request_logs or previous_percent_complete <= WORK_REQUEST_PERCENTAGE
973-
):
974-
time.sleep(poll_interval)
975-
new_work_request_logs = []
976-
977-
try:
978-
work_request = self.client.get_work_request(work_request_id).data
979-
work_request_logs = self.client.list_work_request_logs(
980-
work_request_id
981-
).data
982-
except Exception as ex:
983-
logger.warn(ex)
984-
985-
new_work_request_logs = (
986-
work_request_logs[i:] if work_request_logs else []
987-
)
988-
989-
percent_change = work_request.percent_complete - previous_percent_complete
990-
previous_percent_complete = work_request.percent_complete
991-
992-
if len(new_work_request_logs) > 0:
993-
start_index = True
994-
for wr_item in new_work_request_logs:
995-
if start_index:
996-
progress.update(wr_item.message, percent_change)
997-
start_index = False
998-
else:
999-
progress.update(wr_item.message, 0)
1000-
i += 1
1001-
else:
1002-
# if there is new percent change but the new work request logs is empty
1003-
# needs to add this percent change to the bar to ensure the final percentage is 100
1004-
if percent_change != 0:
1005-
progress.update(n=percent_change)
1006-
1007-
if work_request and work_request.status in WORK_REQUEST_STOP_STATE:
1008-
if work_request.status != "SUCCEEDED":
1009-
if new_work_request_logs:
1010-
raise Exception(new_work_request_logs[-1].message)
1011-
else:
1012-
raise Exception(
1013-
"Error occurred in attempt to perform the operation. "
1014-
"Check the service logs to get more details. "
1015-
f"{work_request}"
1016-
)
1017-
else:
1018-
break
1019-
progress.update("Done")
1020-
1021939

1022940
class OCIModelWithNameMixin:
1023941
"""Mixin class to operate OCI model which contains name property."""

ads/common/work_request.py

Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8; -*-
3+
4+
# Copyright (c) 2024 Oracle and/or its affiliates.
5+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6+
7+
import logging
8+
import sys
9+
import time
10+
from typing import Callable
11+
12+
import oci
13+
from oci import Signer
14+
from tqdm.auto import tqdm
15+
from ads.common.oci_datascience import OCIDataScienceMixin
16+
17+
logger = logging.getLogger(__name__)
18+
19+
WORK_REQUEST_STOP_STATE = ("SUCCEEDED", "FAILED", "CANCELED")
20+
DEFAULT_WAIT_TIME = 1200
21+
DEFAULT_POLL_INTERVAL = 10
22+
WORK_REQUEST_PERCENTAGE = 100
23+
# default tqdm progress bar format:
24+
# {l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, ' '{rate_fmt}{postfix}]
25+
# customize the bar format to remove the {n_fmt}/{total_fmt} from the right side
26+
DEFAULT_BAR_FORMAT = '{l_bar}{bar}| [{elapsed}<{remaining}, ' '{rate_fmt}{postfix}]'
27+
28+
29+
class DataScienceWorkRequest(OCIDataScienceMixin):
30+
"""Class for monitoring OCI WorkRequest and representing on tqdm progress bar. This class inherits
31+
`OCIDataScienceMixin` so as to call its `client` attribute to interact with OCI backend.
32+
"""
33+
34+
def __init__(
35+
self,
36+
id: str,
37+
description: str = "Processing",
38+
config: dict = None,
39+
signer: Signer = None,
40+
client_kwargs: dict = None,
41+
**kwargs
42+
) -> None:
43+
"""Initializes ADSWorkRequest object.
44+
45+
Parameters
46+
----------
47+
id: str
48+
Work Request OCID.
49+
description: str
50+
Progress bar initial step description (Defaults to `Processing`).
51+
config : dict, optional
52+
OCI API key config dictionary to initialize
53+
oci.data_science.DataScienceClient (Defaults to None).
54+
signer : oci.signer.Signer, optional
55+
OCI authentication signer to initialize
56+
oci.data_science.DataScienceClient (Defaults to None).
57+
client_kwargs : dict, optional
58+
Additional client keyword arguments to initialize
59+
oci.data_science.DataScienceClient (Defaults to None).
60+
kwargs:
61+
Additional keyword arguments to initialize
62+
oci.data_science.DataScienceClient.
63+
"""
64+
self.id = id
65+
self._description = description
66+
self._percentage = 0
67+
self._status = None
68+
super().__init__(config, signer, client_kwargs, **kwargs)
69+
70+
71+
def _sync(self):
72+
"""Fetches the latest work request information to ADSWorkRequest object."""
73+
work_request = self.client.get_work_request(self.id).data
74+
work_request_logs = self.client.list_work_request_logs(
75+
self.id
76+
).data
77+
78+
self._percentage= work_request.percent_complete
79+
self._status = work_request.status
80+
self._description = work_request_logs[-1].message if work_request_logs else "Processing"
81+
82+
def watch(
83+
self,
84+
progress_callback: Callable,
85+
max_wait_time: int=DEFAULT_WAIT_TIME,
86+
poll_interval: int=DEFAULT_POLL_INTERVAL,
87+
):
88+
"""Updates the progress bar with realtime message and percentage until the process is completed.
89+
90+
Parameters
91+
----------
92+
progress_callback: Callable
93+
Progress bar callback function.
94+
It must accept `(percent_change, description)` where `percent_change` is the
95+
work request percent complete and `description` is the latest work request log message.
96+
max_wait_time: int
97+
Maximum amount of time to wait in seconds (Defaults to 1200).
98+
Negative implies infinite wait time.
99+
poll_interval: int
100+
Poll interval in seconds (Defaults to 10).
101+
102+
Returns
103+
-------
104+
None
105+
"""
106+
previous_percent_complete = 0
107+
108+
start_time = time.time()
109+
while self._percentage < 100:
110+
111+
seconds_since = time.time() - start_time
112+
if max_wait_time > 0 and seconds_since >= max_wait_time:
113+
logger.error(f"Exceeded max wait time of {max_wait_time} seconds.")
114+
return
115+
116+
time.sleep(poll_interval)
117+
118+
try:
119+
self._sync()
120+
except Exception as ex:
121+
logger.warn(ex)
122+
continue
123+
124+
percent_change = self._percentage - previous_percent_complete
125+
previous_percent_complete = self._percentage
126+
progress_callback(
127+
percent_change=percent_change,
128+
description=self._description
129+
)
130+
131+
if self._status in WORK_REQUEST_STOP_STATE:
132+
if self._status != oci.work_requests.models.WorkRequest.STATUS_SUCCEEDED:
133+
if self._description:
134+
raise Exception(self._description)
135+
else:
136+
raise Exception(
137+
"Error occurred in attempt to perform the operation. "
138+
"Check the service logs to get more details. "
139+
f"Work request id: {self.id}."
140+
)
141+
else:
142+
break
143+
144+
progress_callback(percent_change=0, description="Done")
145+
146+
def wait_work_request(
147+
self,
148+
progress_bar_description: str="Processing",
149+
max_wait_time: int=DEFAULT_WAIT_TIME,
150+
poll_interval: int=DEFAULT_POLL_INTERVAL
151+
):
152+
"""Waits for the work request progress bar to be completed.
153+
154+
Parameters
155+
----------
156+
progress_bar_description: str
157+
Progress bar initial step description (Defaults to `Processing`).
158+
max_wait_time: int
159+
Maximum amount of time to wait in seconds (Defaults to 1200).
160+
Negative implies infinite wait time.
161+
poll_interval: int
162+
Poll interval in seconds (Defaults to 10).
163+
164+
Returns
165+
-------
166+
None
167+
"""
168+
169+
with tqdm(
170+
total=WORK_REQUEST_PERCENTAGE,
171+
leave=False,
172+
mininterval=0,
173+
file=sys.stdout,
174+
desc=progress_bar_description,
175+
bar_format=DEFAULT_BAR_FORMAT
176+
) as pbar:
177+
178+
def progress_callback(percent_change, description):
179+
if percent_change != 0:
180+
pbar.update(percent_change)
181+
if description:
182+
pbar.set_description(description)
183+
184+
self.watch(
185+
progress_callback=progress_callback,
186+
max_wait_time=max_wait_time,
187+
poll_interval=poll_interval
188+
)
189+

0 commit comments

Comments
 (0)