|
5 | 5 | # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ |
6 | 6 |
|
7 | 7 | from functools import wraps |
8 | | -import json |
9 | 8 | import time |
10 | 9 | import logging |
11 | 10 | from typing import Callable, List |
|
19 | 18 | from oci.data_science.models import ( |
20 | 19 | CreateModelDeploymentDetails, |
21 | 20 | UpdateModelDeploymentDetails, |
| 21 | + WorkRequest |
22 | 22 | ) |
23 | 23 |
|
24 | 24 | DEFAULT_WAIT_TIME = 1200 |
@@ -188,20 +188,13 @@ def activate( |
188 | 188 | if wait_for_completion: |
189 | 189 |
|
190 | 190 | self.workflow_req_id = response.headers.get("opc-work-request-id", None) |
191 | | - oci_model_deployment_object = self.client.get_model_deployment(self.id).data |
192 | | - current_state = State._from_str(oci_model_deployment_object.lifecycle_state) |
193 | | - model_deployment_id = self.id |
194 | 191 |
|
195 | 192 | try: |
196 | | - self._wait_for_progress_completion( |
197 | | - State.ACTIVE.name, |
198 | | - ACTIVATE_WORKFLOW_STEPS, |
199 | | - [State.FAILED.name, State.INACTIVE.name], |
200 | | - self.workflow_req_id, |
201 | | - current_state, |
202 | | - model_deployment_id, |
203 | | - max_wait_time, |
204 | | - poll_interval, |
| 193 | + self._wait_for_work_request( |
| 194 | + self.workflow_req_id, |
| 195 | + ACTIVATE_WORKFLOW_STEPS, |
| 196 | + max_wait_time, |
| 197 | + poll_interval |
205 | 198 | ) |
206 | 199 | except Exception as e: |
207 | 200 | logger.error( |
@@ -243,20 +236,13 @@ def create( |
243 | 236 | if wait_for_completion: |
244 | 237 |
|
245 | 238 | self.workflow_req_id = response.headers.get("opc-work-request-id", None) |
246 | | - res_payload = json.loads(str(response.data)) |
247 | | - current_state = State._from_str(res_payload["lifecycle_state"]) |
248 | | - model_deployment_id = self.id |
249 | 239 |
|
250 | 240 | try: |
251 | | - self._wait_for_progress_completion( |
252 | | - State.ACTIVE.name, |
253 | | - CREATE_WORKFLOW_STEPS, |
254 | | - [State.FAILED.name, State.INACTIVE.name], |
255 | | - self.workflow_req_id, |
256 | | - current_state, |
257 | | - model_deployment_id, |
258 | | - max_wait_time, |
259 | | - poll_interval, |
| 241 | + self._wait_for_work_request( |
| 242 | + self.workflow_req_id, |
| 243 | + CREATE_WORKFLOW_STEPS, |
| 244 | + max_wait_time, |
| 245 | + poll_interval |
260 | 246 | ) |
261 | 247 | except Exception as e: |
262 | 248 | logger.error( |
@@ -301,20 +287,13 @@ def deactivate( |
301 | 287 | if wait_for_completion: |
302 | 288 |
|
303 | 289 | self.workflow_req_id = response.headers.get("opc-work-request-id", None) |
304 | | - oci_model_deployment_object = self.client.get_model_deployment(self.id).data |
305 | | - current_state = State._from_str(oci_model_deployment_object.lifecycle_state) |
306 | | - model_deployment_id = self.id |
307 | 290 |
|
308 | 291 | try: |
309 | | - self._wait_for_progress_completion( |
310 | | - State.INACTIVE.name, |
311 | | - DEACTIVATE_WORKFLOW_STEPS, |
312 | | - [State.FAILED.name], |
313 | | - self.workflow_req_id, |
314 | | - current_state, |
315 | | - model_deployment_id, |
316 | | - max_wait_time, |
317 | | - poll_interval, |
| 292 | + self._wait_for_work_request( |
| 293 | + self.workflow_req_id, |
| 294 | + DEACTIVATE_WORKFLOW_STEPS, |
| 295 | + max_wait_time, |
| 296 | + poll_interval |
318 | 297 | ) |
319 | 298 | except Exception as e: |
320 | 299 | logger.error( |
@@ -359,20 +338,13 @@ def delete( |
359 | 338 | if wait_for_completion: |
360 | 339 |
|
361 | 340 | self.workflow_req_id = response.headers.get("opc-work-request-id", None) |
362 | | - oci_model_deployment_object = self.client.get_model_deployment(self.id).data |
363 | | - current_state = State._from_str(oci_model_deployment_object.lifecycle_state) |
364 | | - model_deployment_id = self.id |
365 | 341 |
|
366 | 342 | try: |
367 | | - self._wait_for_progress_completion( |
368 | | - State.DELETED.name, |
369 | | - DELETE_WORKFLOW_STEPS, |
370 | | - [State.FAILED.name, State.INACTIVE.name], |
371 | | - self.workflow_req_id, |
372 | | - current_state, |
373 | | - model_deployment_id, |
374 | | - max_wait_time, |
375 | | - poll_interval, |
| 343 | + self._wait_for_work_request( |
| 344 | + self.workflow_req_id, |
| 345 | + DELETE_WORKFLOW_STEPS, |
| 346 | + max_wait_time, |
| 347 | + poll_interval |
376 | 348 | ) |
377 | 349 | except Exception as e: |
378 | 350 | logger.error( |
@@ -512,89 +484,76 @@ def from_id(cls, model_deployment_id: str) -> "OCIDataScienceModelDeployment": |
512 | 484 | """ |
513 | 485 | return super().from_ocid(model_deployment_id) |
514 | 486 |
|
515 | | - def _wait_for_progress_completion( |
516 | | - self, |
517 | | - final_state: str, |
518 | | - work_flow_step: int, |
519 | | - disallowed_final_states: List[str], |
520 | | - work_flow_request_id: str, |
521 | | - state: State, |
522 | | - model_deployment_id: str, |
523 | | - max_wait_time: int = DEFAULT_WAIT_TIME, |
524 | | - poll_interval: int = DEFAULT_POLL_INTERVAL, |
525 | | - ): |
526 | | - """_wait_for_progress_completion blocks until progress is completed. |
| 487 | + def _wait_for_work_request( |
| 488 | + self, |
| 489 | + work_request_id: str, |
| 490 | + num_steps: int = DELETE_WORKFLOW_STEPS, |
| 491 | + max_wait_time: int = DEFAULT_WAIT_TIME, |
| 492 | + poll_interval: int = DEFAULT_POLL_INTERVAL |
| 493 | + ) -> None: |
| 494 | + """Waits for the work request to be completed. |
527 | 495 |
|
528 | 496 | Parameters |
529 | 497 | ---------- |
530 | | - final_state: str |
531 | | - Final state of model deployment aimed to be reached. |
532 | | - work_flow_step: int |
533 | | - Number of work flow step of the request. |
534 | | - disallowed_final_states: list[str] |
535 | | - List of disallowed final state to be reached. |
536 | | - work_flow_request_id: str |
537 | | - The id of work flow request. |
538 | | - state: State |
539 | | - The current state of model deployment. |
540 | | - model_deployment_id: str |
541 | | - The ocid of model deployment. |
| 498 | + work_request_id: str |
| 499 | + Work Request OCID. |
| 500 | + num_steps: (int, optional). Defaults to 6. |
| 501 | + Number of steps for the progress indicator. |
542 | 502 | max_wait_time: int |
543 | 503 | Maximum amount of time to wait in seconds (Defaults to 1200). |
544 | 504 | Negative implies infinite wait time. |
545 | 505 | poll_interval: int |
546 | 506 | Poll interval in seconds (Defaults to 10). |
| 507 | +
|
| 508 | + Returns |
| 509 | + ------- |
| 510 | + None |
547 | 511 | """ |
| 512 | + STOP_STATE = ( |
| 513 | + WorkRequest.STATUS_SUCCEEDED, |
| 514 | + WorkRequest.STATUS_CANCELED, |
| 515 | + WorkRequest.STATUS_FAILED, |
| 516 | + ) |
| 517 | + work_request_logs = [] |
548 | 518 |
|
| 519 | + i = 0 |
549 | 520 | start_time = time.time() |
550 | | - prev_message = "" |
551 | | - prev_workflow_stage_len = 0 |
552 | | - current_state = state or State.UNKNOWN |
553 | | - with progress_bar_utils.get_progress_bar(work_flow_step) as progress: |
554 | | - if max_wait_time > 0 and utils.seconds_since(start_time) >= max_wait_time: |
555 | | - utils.get_logger().error( |
| 521 | + with progress_bar_utils.get_progress_bar(num_steps) as progress: |
| 522 | + exceed_max_time = max_wait_time > 0 and utils.seconds_since(start_time) >= max_wait_time |
| 523 | + if exceed_max_time: |
| 524 | + logger.error( |
556 | 525 | f"Max wait time ({max_wait_time} seconds) exceeded." |
557 | 526 | ) |
558 | | - while ( |
559 | | - max_wait_time < 0 or utils.seconds_since(start_time) < max_wait_time |
560 | | - ) and current_state.name.upper() != final_state: |
561 | | - if current_state.name.upper() in disallowed_final_states: |
562 | | - utils.get_logger().info( |
563 | | - f"Operation failed due to deployment reaching state {current_state.name.upper()}. Use Deployment ID for further steps." |
564 | | - ) |
565 | | - break |
566 | | - |
567 | | - prev_state = current_state.name |
| 527 | + while not exceed_max_time and (not work_request_logs or len(work_request_logs) < num_steps): |
| 528 | + time.sleep(poll_interval) |
| 529 | + new_work_request_logs = [] |
| 530 | + |
568 | 531 | try: |
569 | | - model_deployment_payload = json.loads( |
570 | | - str(self.client.get_model_deployment(model_deployment_id).data) |
571 | | - ) |
572 | | - current_state = ( |
573 | | - State._from_str(model_deployment_payload["lifecycle_state"]) |
574 | | - if "lifecycle_state" in model_deployment_payload |
575 | | - else State.UNKNOWN |
576 | | - ) |
577 | | - workflow_payload = self.client.list_work_request_logs( |
578 | | - work_flow_request_id |
| 532 | + work_request = self.client.get_work_request(work_request_id).data |
| 533 | + work_request_logs = self.client.list_work_request_logs( |
| 534 | + work_request_id |
579 | 535 | ).data |
580 | | - if isinstance(workflow_payload, list) and len(workflow_payload) > 0: |
581 | | - if prev_message != workflow_payload[-1].message: |
582 | | - for _ in range( |
583 | | - len(workflow_payload) - prev_workflow_stage_len |
584 | | - ): |
585 | | - progress.update(workflow_payload[-1].message) |
586 | | - prev_workflow_stage_len = len(workflow_payload) |
587 | | - prev_message = workflow_payload[-1].message |
588 | | - prev_workflow_stage_len = len(workflow_payload) |
589 | | - if prev_state != current_state.name: |
590 | | - utils.get_logger().info( |
591 | | - f"Status Update: {current_state.name} in {utils.seconds_since(start_time)} seconds" |
592 | | - ) |
593 | | - except Exception as e: |
594 | | - # utils.get_logger().warning( |
595 | | - # "Unable to update deployment status. Details: %s", format( |
596 | | - # e) |
597 | | - # ) |
598 | | - pass |
599 | | - time.sleep(poll_interval) |
| 536 | + except Exception as ex: |
| 537 | + logger.warn(ex) |
| 538 | + |
| 539 | + new_work_request_logs = ( |
| 540 | + work_request_logs[i:] if work_request_logs else [] |
| 541 | + ) |
| 542 | + |
| 543 | + for wr_item in new_work_request_logs: |
| 544 | + progress.update(wr_item.message) |
| 545 | + i += 1 |
| 546 | + |
| 547 | + if work_request and work_request.status in STOP_STATE: |
| 548 | + if work_request.status != WorkRequest.STATUS_SUCCEEDED: |
| 549 | + if new_work_request_logs: |
| 550 | + raise Exception(new_work_request_logs[-1].message) |
| 551 | + else: |
| 552 | + raise Exception( |
| 553 | + "Error occurred in attempt to perform the operation. " |
| 554 | + "Check the service logs to get more details. " |
| 555 | + f"{work_request}" |
| 556 | + ) |
| 557 | + else: |
| 558 | + break |
600 | 559 | progress.update("Done") |
0 commit comments