|
5 | 5 | # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ |
6 | 6 |
|
7 | 7 | from functools import wraps |
8 | | -import time |
9 | 8 | import logging |
10 | 9 | from typing import Callable, List |
11 | 10 | from ads.common.oci_datascience import OCIDataScienceMixin |
12 | | -from ads.common import utils as progress_bar_utils |
| 11 | +from ads.common.oci_mixin import OCIWorkRequestMixin |
13 | 12 | from ads.config import PROJECT_OCID |
14 | | -from ads.model.deployment.common import utils |
15 | 13 | from ads.model.deployment.common.utils import OCIClientManager, State |
16 | 14 | import oci |
17 | 15 |
|
18 | 16 | from oci.data_science.models import ( |
19 | 17 | CreateModelDeploymentDetails, |
20 | 18 | UpdateModelDeploymentDetails, |
21 | | - WorkRequest |
22 | 19 | ) |
23 | 20 |
|
24 | 21 | DEFAULT_WAIT_TIME = 1200 |
@@ -84,6 +81,7 @@ class MissingModelDeploymentWorkflowIdError(Exception): |
84 | 81 |
|
85 | 82 | class OCIDataScienceModelDeployment( |
86 | 83 | OCIDataScienceMixin, |
| 84 | + OCIWorkRequestMixin, |
87 | 85 | oci.data_science.models.ModelDeployment, |
88 | 86 | ): |
89 | 87 | """Represents an OCI Data Science Model Deployment. |
@@ -190,7 +188,7 @@ def activate( |
190 | 188 | self.workflow_req_id = response.headers.get("opc-work-request-id", None) |
191 | 189 |
|
192 | 190 | try: |
193 | | - self._wait_for_work_request( |
| 191 | + self.wait_for_progress( |
194 | 192 | self.workflow_req_id, |
195 | 193 | ACTIVATE_WORKFLOW_STEPS, |
196 | 194 | max_wait_time, |
@@ -238,7 +236,7 @@ def create( |
238 | 236 | self.workflow_req_id = response.headers.get("opc-work-request-id", None) |
239 | 237 |
|
240 | 238 | try: |
241 | | - self._wait_for_work_request( |
| 239 | + self.wait_for_progress( |
242 | 240 | self.workflow_req_id, |
243 | 241 | CREATE_WORKFLOW_STEPS, |
244 | 242 | max_wait_time, |
@@ -289,7 +287,7 @@ def deactivate( |
289 | 287 | self.workflow_req_id = response.headers.get("opc-work-request-id", None) |
290 | 288 |
|
291 | 289 | try: |
292 | | - self._wait_for_work_request( |
| 290 | + self.wait_for_progress( |
293 | 291 | self.workflow_req_id, |
294 | 292 | DEACTIVATE_WORKFLOW_STEPS, |
295 | 293 | max_wait_time, |
@@ -340,7 +338,7 @@ def delete( |
340 | 338 | self.workflow_req_id = response.headers.get("opc-work-request-id", None) |
341 | 339 |
|
342 | 340 | try: |
343 | | - self._wait_for_work_request( |
| 341 | + self.wait_for_progress( |
344 | 342 | self.workflow_req_id, |
345 | 343 | DELETE_WORKFLOW_STEPS, |
346 | 344 | max_wait_time, |
@@ -483,77 +481,3 @@ def from_id(cls, model_deployment_id: str) -> "OCIDataScienceModelDeployment": |
483 | 481 | An instance of `OCIDataScienceModelDeployment`. |
484 | 482 | """ |
485 | 483 | return super().from_ocid(model_deployment_id) |
486 | | - |
487 | | - def _wait_for_work_request( |
488 | | - self, |
489 | | - work_request_id: str, |
490 | | - num_steps: int = DELETE_WORKFLOW_STEPS, |
491 | | - max_wait_time: int = DEFAULT_WAIT_TIME, |
492 | | - poll_interval: int = DEFAULT_POLL_INTERVAL |
493 | | - ) -> None: |
494 | | - """Waits for the work request to be completed. |
495 | | -
|
496 | | - Parameters |
497 | | - ---------- |
498 | | - work_request_id: str |
499 | | - Work Request OCID. |
500 | | - num_steps: (int, optional). Defaults to 6. |
501 | | - Number of steps for the progress indicator. |
502 | | - max_wait_time: int |
503 | | - Maximum amount of time to wait in seconds (Defaults to 1200). |
504 | | - Negative implies infinite wait time. |
505 | | - poll_interval: int |
506 | | - Poll interval in seconds (Defaults to 10). |
507 | | -
|
508 | | - Returns |
509 | | - ------- |
510 | | - None |
511 | | - """ |
512 | | - STOP_STATE = ( |
513 | | - WorkRequest.STATUS_SUCCEEDED, |
514 | | - WorkRequest.STATUS_CANCELED, |
515 | | - WorkRequest.STATUS_FAILED, |
516 | | - ) |
517 | | - work_request_logs = [] |
518 | | - |
519 | | - i = 0 |
520 | | - start_time = time.time() |
521 | | - with progress_bar_utils.get_progress_bar(num_steps) as progress: |
522 | | - exceed_max_time = max_wait_time > 0 and utils.seconds_since(start_time) >= max_wait_time |
523 | | - if exceed_max_time: |
524 | | - logger.error( |
525 | | - f"Max wait time ({max_wait_time} seconds) exceeded." |
526 | | - ) |
527 | | - while not exceed_max_time and (not work_request_logs or len(work_request_logs) < num_steps): |
528 | | - time.sleep(poll_interval) |
529 | | - new_work_request_logs = [] |
530 | | - |
531 | | - try: |
532 | | - work_request = self.client.get_work_request(work_request_id).data |
533 | | - work_request_logs = self.client.list_work_request_logs( |
534 | | - work_request_id |
535 | | - ).data |
536 | | - except Exception as ex: |
537 | | - logger.warn(ex) |
538 | | - |
539 | | - new_work_request_logs = ( |
540 | | - work_request_logs[i:] if work_request_logs else [] |
541 | | - ) |
542 | | - |
543 | | - for wr_item in new_work_request_logs: |
544 | | - progress.update(wr_item.message) |
545 | | - i += 1 |
546 | | - |
547 | | - if work_request and work_request.status in STOP_STATE: |
548 | | - if work_request.status != WorkRequest.STATUS_SUCCEEDED: |
549 | | - if new_work_request_logs: |
550 | | - raise Exception(new_work_request_logs[-1].message) |
551 | | - else: |
552 | | - raise Exception( |
553 | | - "Error occurred in attempt to perform the operation. " |
554 | | - "Check the service logs to get more details. " |
555 | | - f"{work_request}" |
556 | | - ) |
557 | | - else: |
558 | | - break |
559 | | - progress.update("Done") |
0 commit comments