|
5 | 5 | # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ |
6 | 6 |
|
7 | 7 | from functools import wraps |
8 | | -import json |
9 | | -import time |
10 | 8 | import logging |
11 | 9 | from typing import Callable, List |
12 | 10 | from ads.common.oci_datascience import OCIDataScienceMixin |
13 | | -from ads.common import utils as progress_bar_utils |
| 11 | +from ads.common.oci_mixin import OCIWorkRequestMixin |
14 | 12 | from ads.config import PROJECT_OCID |
15 | | -from ads.model.deployment.common import utils |
16 | 13 | from ads.model.deployment.common.utils import OCIClientManager, State |
17 | 14 | import oci |
18 | 15 |
|
@@ -84,6 +81,7 @@ class MissingModelDeploymentWorkflowIdError(Exception): # pragma: no cover |
84 | 81 |
|
85 | 82 | class OCIDataScienceModelDeployment( |
86 | 83 | OCIDataScienceMixin, |
| 84 | + OCIWorkRequestMixin, |
87 | 85 | oci.data_science.models.ModelDeployment, |
88 | 86 | ): |
89 | 87 | """Represents an OCI Data Science Model Deployment. |
@@ -188,20 +186,13 @@ def activate( |
188 | 186 | if wait_for_completion: |
189 | 187 |
|
190 | 188 | self.workflow_req_id = response.headers.get("opc-work-request-id", None) |
191 | | - oci_model_deployment_object = self.client.get_model_deployment(self.id).data |
192 | | - current_state = State._from_str(oci_model_deployment_object.lifecycle_state) |
193 | | - model_deployment_id = self.id |
194 | 189 |
|
195 | 190 | try: |
196 | | - self._wait_for_progress_completion( |
197 | | - State.ACTIVE.name, |
198 | | - ACTIVATE_WORKFLOW_STEPS, |
199 | | - [State.FAILED.name, State.INACTIVE.name], |
200 | | - self.workflow_req_id, |
201 | | - current_state, |
202 | | - model_deployment_id, |
203 | | - max_wait_time, |
204 | | - poll_interval, |
| 191 | + self.wait_for_progress( |
| 192 | + self.workflow_req_id, |
| 193 | + ACTIVATE_WORKFLOW_STEPS, |
| 194 | + max_wait_time, |
| 195 | + poll_interval |
205 | 196 | ) |
206 | 197 | except Exception as e: |
207 | 198 | logger.error( |
@@ -243,20 +234,13 @@ def create( |
243 | 234 | if wait_for_completion: |
244 | 235 |
|
245 | 236 | self.workflow_req_id = response.headers.get("opc-work-request-id", None) |
246 | | - res_payload = json.loads(str(response.data)) |
247 | | - current_state = State._from_str(res_payload["lifecycle_state"]) |
248 | | - model_deployment_id = self.id |
249 | 237 |
|
250 | 238 | try: |
251 | | - self._wait_for_progress_completion( |
252 | | - State.ACTIVE.name, |
253 | | - CREATE_WORKFLOW_STEPS, |
254 | | - [State.FAILED.name, State.INACTIVE.name], |
255 | | - self.workflow_req_id, |
256 | | - current_state, |
257 | | - model_deployment_id, |
258 | | - max_wait_time, |
259 | | - poll_interval, |
| 239 | + self.wait_for_progress( |
| 240 | + self.workflow_req_id, |
| 241 | + CREATE_WORKFLOW_STEPS, |
| 242 | + max_wait_time, |
| 243 | + poll_interval |
260 | 244 | ) |
261 | 245 | except Exception as e: |
262 | 246 | logger.error( |
@@ -301,20 +285,13 @@ def deactivate( |
301 | 285 | if wait_for_completion: |
302 | 286 |
|
303 | 287 | self.workflow_req_id = response.headers.get("opc-work-request-id", None) |
304 | | - oci_model_deployment_object = self.client.get_model_deployment(self.id).data |
305 | | - current_state = State._from_str(oci_model_deployment_object.lifecycle_state) |
306 | | - model_deployment_id = self.id |
307 | 288 |
|
308 | 289 | try: |
309 | | - self._wait_for_progress_completion( |
310 | | - State.INACTIVE.name, |
311 | | - DEACTIVATE_WORKFLOW_STEPS, |
312 | | - [State.FAILED.name], |
313 | | - self.workflow_req_id, |
314 | | - current_state, |
315 | | - model_deployment_id, |
316 | | - max_wait_time, |
317 | | - poll_interval, |
| 290 | + self.wait_for_progress( |
| 291 | + self.workflow_req_id, |
| 292 | + DEACTIVATE_WORKFLOW_STEPS, |
| 293 | + max_wait_time, |
| 294 | + poll_interval |
318 | 295 | ) |
319 | 296 | except Exception as e: |
320 | 297 | logger.error( |
@@ -359,20 +336,13 @@ def delete( |
359 | 336 | if wait_for_completion: |
360 | 337 |
|
361 | 338 | self.workflow_req_id = response.headers.get("opc-work-request-id", None) |
362 | | - oci_model_deployment_object = self.client.get_model_deployment(self.id).data |
363 | | - current_state = State._from_str(oci_model_deployment_object.lifecycle_state) |
364 | | - model_deployment_id = self.id |
365 | 339 |
|
366 | 340 | try: |
367 | | - self._wait_for_progress_completion( |
368 | | - State.DELETED.name, |
369 | | - DELETE_WORKFLOW_STEPS, |
370 | | - [State.FAILED.name, State.INACTIVE.name], |
371 | | - self.workflow_req_id, |
372 | | - current_state, |
373 | | - model_deployment_id, |
374 | | - max_wait_time, |
375 | | - poll_interval, |
| 341 | + self.wait_for_progress( |
| 342 | + self.workflow_req_id, |
| 343 | + DELETE_WORKFLOW_STEPS, |
| 344 | + max_wait_time, |
| 345 | + poll_interval |
376 | 346 | ) |
377 | 347 | except Exception as e: |
378 | 348 | logger.error( |
@@ -511,90 +481,3 @@ def from_id(cls, model_deployment_id: str) -> "OCIDataScienceModelDeployment": |
511 | 481 | An instance of `OCIDataScienceModelDeployment`. |
512 | 482 | """ |
513 | 483 | return super().from_ocid(model_deployment_id) |
514 | | - |
515 | | - def _wait_for_progress_completion( |
516 | | - self, |
517 | | - final_state: str, |
518 | | - work_flow_step: int, |
519 | | - disallowed_final_states: List[str], |
520 | | - work_flow_request_id: str, |
521 | | - state: State, |
522 | | - model_deployment_id: str, |
523 | | - max_wait_time: int = DEFAULT_WAIT_TIME, |
524 | | - poll_interval: int = DEFAULT_POLL_INTERVAL, |
525 | | - ): |
526 | | - """_wait_for_progress_completion blocks until progress is completed. |
527 | | -
|
528 | | - Parameters |
529 | | - ---------- |
530 | | - final_state: str |
531 | | - Final state of model deployment aimed to be reached. |
532 | | - work_flow_step: int |
533 | | - Number of work flow step of the request. |
534 | | - disallowed_final_states: list[str] |
535 | | - List of disallowed final state to be reached. |
536 | | - work_flow_request_id: str |
537 | | - The id of work flow request. |
538 | | - state: State |
539 | | - The current state of model deployment. |
540 | | - model_deployment_id: str |
541 | | - The ocid of model deployment. |
542 | | - max_wait_time: int |
543 | | - Maximum amount of time to wait in seconds (Defaults to 1200). |
544 | | - Negative implies infinite wait time. |
545 | | - poll_interval: int |
546 | | - Poll interval in seconds (Defaults to 10). |
547 | | - """ |
548 | | - |
549 | | - start_time = time.time() |
550 | | - prev_message = "" |
551 | | - prev_workflow_stage_len = 0 |
552 | | - current_state = state or State.UNKNOWN |
553 | | - with progress_bar_utils.get_progress_bar(work_flow_step) as progress: |
554 | | - if max_wait_time > 0 and utils.seconds_since(start_time) >= max_wait_time: |
555 | | - utils.get_logger().error( |
556 | | - f"Max wait time ({max_wait_time} seconds) exceeded." |
557 | | - ) |
558 | | - while ( |
559 | | - max_wait_time < 0 or utils.seconds_since(start_time) < max_wait_time |
560 | | - ) and current_state.name.upper() != final_state: |
561 | | - if current_state.name.upper() in disallowed_final_states: |
562 | | - utils.get_logger().info( |
563 | | - f"Operation failed due to deployment reaching state {current_state.name.upper()}. Use Deployment ID for further steps." |
564 | | - ) |
565 | | - break |
566 | | - |
567 | | - prev_state = current_state.name |
568 | | - try: |
569 | | - model_deployment_payload = json.loads( |
570 | | - str(self.client.get_model_deployment(model_deployment_id).data) |
571 | | - ) |
572 | | - current_state = ( |
573 | | - State._from_str(model_deployment_payload["lifecycle_state"]) |
574 | | - if "lifecycle_state" in model_deployment_payload |
575 | | - else State.UNKNOWN |
576 | | - ) |
577 | | - workflow_payload = self.client.list_work_request_logs( |
578 | | - work_flow_request_id |
579 | | - ).data |
580 | | - if isinstance(workflow_payload, list) and len(workflow_payload) > 0: |
581 | | - if prev_message != workflow_payload[-1].message: |
582 | | - for _ in range( |
583 | | - len(workflow_payload) - prev_workflow_stage_len |
584 | | - ): |
585 | | - progress.update(workflow_payload[-1].message) |
586 | | - prev_workflow_stage_len = len(workflow_payload) |
587 | | - prev_message = workflow_payload[-1].message |
588 | | - prev_workflow_stage_len = len(workflow_payload) |
589 | | - if prev_state != current_state.name: |
590 | | - utils.get_logger().info( |
591 | | - f"Status Update: {current_state.name} in {utils.seconds_since(start_time)} seconds" |
592 | | - ) |
593 | | - except Exception as e: |
594 | | - # utils.get_logger().warning( |
595 | | - # "Unable to update deployment status. Details: %s", format( |
596 | | - # e) |
597 | | - # ) |
598 | | - pass |
599 | | - time.sleep(poll_interval) |
600 | | - progress.update("Done") |
0 commit comments