|
1 | 1 | #!/usr/bin/env python |
2 | | -# -*- coding: utf-8; -*- |
3 | 2 |
|
4 | | -# Copyright (c) 2024 Oracle and/or its affiliates. |
| 3 | +# Copyright (c) 2024, 2025 Oracle and/or its affiliates. |
5 | 4 | # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ |
6 | 5 |
|
7 | | -from functools import wraps |
8 | 6 | import logging |
9 | | -from typing import Callable, List |
10 | | -from ads.common.oci_datascience import OCIDataScienceMixin |
11 | | -from ads.common.work_request import DataScienceWorkRequest |
12 | | -from ads.config import PROJECT_OCID |
13 | | -from ads.model.deployment.common.utils import OCIClientManager, State |
14 | | -import oci |
| 7 | +from functools import wraps |
| 8 | +from typing import Callable, List, Optional |
15 | 9 |
|
| 10 | +import oci |
16 | 11 | from oci.data_science.models import ( |
17 | 12 | CreateModelDeploymentDetails, |
| 13 | + ModelDeploymentShapeSummary, |
18 | 14 | UpdateModelDeploymentDetails, |
19 | 15 | ) |
20 | 16 |
|
| 17 | +from ads.common.oci_datascience import OCIDataScienceMixin |
| 18 | +from ads.common.work_request import DataScienceWorkRequest |
| 19 | +from ads.config import COMPARTMENT_OCID, PROJECT_OCID |
| 20 | +from ads.model.deployment.common.utils import OCIClientManager, State |
| 21 | + |
21 | 22 | DEFAULT_WAIT_TIME = 1200 |
22 | 23 | DEFAULT_POLL_INTERVAL = 10 |
23 | 24 | ALLOWED_STATUS = [ |
@@ -185,14 +186,13 @@ def activate( |
185 | 186 | self.id, |
186 | 187 | ) |
187 | 188 |
|
188 | | - |
189 | 189 | self.workflow_req_id = response.headers.get("opc-work-request-id", None) |
190 | 190 | if wait_for_completion: |
191 | 191 | try: |
192 | 192 | DataScienceWorkRequest(self.workflow_req_id).wait_work_request( |
193 | 193 | progress_bar_description="Activating model deployment", |
194 | | - max_wait_time=max_wait_time, |
195 | | - poll_interval=poll_interval |
| 194 | + max_wait_time=max_wait_time, |
| 195 | + poll_interval=poll_interval, |
196 | 196 | ) |
197 | 197 | except Exception as e: |
198 | 198 | logger.error( |
@@ -239,8 +239,8 @@ def create( |
239 | 239 | try: |
240 | 240 | DataScienceWorkRequest(self.workflow_req_id).wait_work_request( |
241 | 241 | progress_bar_description="Creating model deployment", |
242 | | - max_wait_time=max_wait_time, |
243 | | - poll_interval=poll_interval |
| 242 | + max_wait_time=max_wait_time, |
| 243 | + poll_interval=poll_interval, |
244 | 244 | ) |
245 | 245 | except Exception as e: |
246 | 246 | logger.error("Error while trying to create model deployment: " + str(e)) |
@@ -290,8 +290,8 @@ def deactivate( |
290 | 290 | try: |
291 | 291 | DataScienceWorkRequest(self.workflow_req_id).wait_work_request( |
292 | 292 | progress_bar_description="Deactivating model deployment", |
293 | | - max_wait_time=max_wait_time, |
294 | | - poll_interval=poll_interval |
| 293 | + max_wait_time=max_wait_time, |
| 294 | + poll_interval=poll_interval, |
295 | 295 | ) |
296 | 296 | except Exception as e: |
297 | 297 | logger.error( |
@@ -351,14 +351,14 @@ def delete( |
351 | 351 | response = self.client.delete_model_deployment( |
352 | 352 | self.id, |
353 | 353 | ) |
354 | | - |
| 354 | + |
355 | 355 | self.workflow_req_id = response.headers.get("opc-work-request-id", None) |
356 | 356 | if wait_for_completion: |
357 | 357 | try: |
358 | 358 | DataScienceWorkRequest(self.workflow_req_id).wait_work_request( |
359 | 359 | progress_bar_description="Deleting model deployment", |
360 | | - max_wait_time=max_wait_time, |
361 | | - poll_interval=poll_interval |
| 360 | + max_wait_time=max_wait_time, |
| 361 | + poll_interval=poll_interval, |
362 | 362 | ) |
363 | 363 | except Exception as e: |
364 | 364 | logger.error("Error while trying to delete model deployment: " + str(e)) |
@@ -493,3 +493,30 @@ def from_id(cls, model_deployment_id: str) -> "OCIDataScienceModelDeployment": |
493 | 493 | An instance of `OCIDataScienceModelDeployment`. |
494 | 494 | """ |
495 | 495 | return super().from_ocid(model_deployment_id) |
| 496 | + |
| 497 | + @classmethod |
| 498 | + def shapes( |
| 499 | + cls, |
| 500 | + compartment_id: Optional[str] = None, |
| 501 | + **kwargs, |
| 502 | + ) -> List[ModelDeploymentShapeSummary]: |
| 503 | + """ |
| 504 | + Retrieves all available model deployment shapes in the given compartment. |
| 505 | +
|
| 506 | + This method uses OCI's pagination utility to fetch all pages of model |
| 507 | + deployment shape summaries available in the specified compartment. |
| 508 | +
|
| 509 | + Args: |
| 510 | + compartment_id (Optional[str]): The OCID of the compartment. If not provided, |
| 511 | + the default COMPARTMENT_ID extracted form env variables is used. |
| 512 | + **kwargs: Additional keyword arguments to pass to the list_model_deployments call. |
| 513 | +
|
| 514 | + Returns: |
| 515 | + List[ModelDeploymentShapeSummary]: A list of all model deployment shape summaries. |
| 516 | + """ |
| 517 | + client = cls().client |
| 518 | + compartment_id = compartment_id or COMPARTMENT_OCID |
| 519 | + |
| 520 | + return oci.pagination.list_call_get_all_results( |
| 521 | + client.list_model_deployment_shapes, compartment_id, **kwargs |
| 522 | + ).data |
0 commit comments