|
8 | 8 | import shlex |
9 | 9 | import threading |
10 | 10 | from datetime import datetime, timedelta |
11 | | -from typing import Dict, List, Optional |
| 11 | +from typing import Dict, List, Optional, Union |
12 | 12 |
|
13 | 13 | from cachetools import TTLCache, cached |
14 | 14 | from oci.data_science.models import ModelDeploymentShapeSummary |
15 | 15 | from pydantic import ValidationError |
| 16 | +from rich.table import Table |
16 | 17 |
|
17 | 18 | from ads.aqua.app import AquaApp, logger |
18 | 19 | from ads.aqua.common.entities import ( |
|
44 | 45 | AQUA_MODEL_TYPE_SERVICE, |
45 | 46 | AQUA_MULTI_MODEL_CONFIG, |
46 | 47 | MODEL_BY_REFERENCE_OSS_PATH_KEY, |
| 48 | + MODEL_GROUP, |
47 | 49 | MODEL_NAME_DELIMITER, |
| 50 | + SINGLE_MODEL_FLEX, |
48 | 51 | UNKNOWN_DICT, |
| 52 | + UNKNOWN_ENUM_VALUE, |
49 | 53 | ) |
50 | 54 | from ads.aqua.data import AquaResourceIdentifier |
51 | 55 | from ads.aqua.model import AquaModelApp |
|
64 | 68 | ModelDeploymentConfigSummary, |
65 | 69 | MultiModelDeploymentConfigLoader, |
66 | 70 | ) |
67 | | -from ads.aqua.modeldeployment.constants import DEFAULT_POLL_INTERVAL, DEFAULT_WAIT_TIME |
| 71 | +from ads.aqua.modeldeployment.constants import ( |
| 72 | + DEFAULT_POLL_INTERVAL, |
| 73 | + DEFAULT_WAIT_TIME, |
| 74 | +) |
68 | 75 | from ads.aqua.modeldeployment.entities import ( |
69 | 76 | AquaDeployment, |
70 | 77 | AquaDeploymentDetail, |
71 | 78 | ConfigValidationError, |
72 | 79 | CreateModelDeploymentDetails, |
73 | 80 | ) |
74 | 81 | from ads.aqua.modeldeployment.model_group_config import ModelGroupConfig |
| 82 | +from ads.aqua.shaperecommend.recommend import AquaShapeRecommend |
| 83 | +from ads.aqua.shaperecommend.shape_report import ( |
| 84 | + RequestRecommend, |
| 85 | + ShapeRecommendationReport, |
| 86 | +) |
75 | 87 | from ads.common.object_storage_details import ObjectStorageDetails |
76 | 88 | from ads.common.utils import UNKNOWN, get_log_links |
77 | 89 | from ads.common.work_request import DataScienceWorkRequest |
@@ -864,21 +876,26 @@ def list(self, **kwargs) -> List["AquaDeployment"]: |
864 | 876 |
|
865 | 877 | if oci_aqua: |
866 | 878 | # skip AQUA model deployments whose deployment type is UNKNOWN_ENUM_VALUE, MODEL_GROUP or SINGLE_MODEL_FLEX |
867 | | - # TODO: remove this checker after AQUA deployment is integrated with model group |
868 | | - aqua_model_id = model_deployment.freeform_tags.get( |
869 | | - Tags.AQUA_MODEL_ID_TAG, UNKNOWN |
870 | | - ) |
871 | 879 | if ( |
872 | | - "datasciencemodelgroup" in aqua_model_id |
873 | | - or model_deployment.model_deployment_configuration_details.deployment_type |
874 | | - == "UNKNOWN_ENUM_VALUE" |
| 880 | + model_deployment.model_deployment_configuration_details.deployment_type |
| 881 | + in [UNKNOWN_ENUM_VALUE, MODEL_GROUP, SINGLE_MODEL_FLEX] |
875 | 882 | ): |
876 | 883 | continue |
877 | | - results.append( |
878 | | - AquaDeployment.from_oci_model_deployment( |
879 | | - model_deployment, self.region |
| 884 | + try: |
| 885 | + results.append( |
| 886 | + AquaDeployment.from_oci_model_deployment( |
| 887 | + model_deployment, self.region |
| 888 | + ) |
880 | 889 | ) |
881 | | - ) |
| 890 | + except Exception as e: |
| 891 | + logger.error( |
| 892 | + f"There was an issue processing the list of model deployments . Error: {str(e)}", |
| 893 | + exc_info=True, |
| 894 | + ) |
| 895 | + raise AquaRuntimeError( |
| 896 | + f"There was an issue processing the list of model deployments . Error: {str(e)}" |
| 897 | + ) from e |
| 898 | + |
882 | 899 | # log telemetry if MD is in active or failed state |
883 | 900 | deployment_id = model_deployment.id |
884 | 901 | state = model_deployment.lifecycle_state.upper() |
@@ -1249,6 +1266,50 @@ def validate_deployment_params( |
1249 | 1266 | ) |
1250 | 1267 | return {"valid": True} |
1251 | 1268 |
|
def recommend_shape(self, **kwargs) -> Union[Table, ShapeRecommendationReport]:
    """
    Recommends valid GPU deployment shapes for the provided model and configuration.

    The keyword arguments are validated by constructing a ``RequestRecommend``;
    the validated request is then handed to ``AquaShapeRecommend.which_shapes``
    and its result is returned unchanged.

    For the CLI (set generate_table = True), the result is a rich table with the
    valid GPU deployment shapes. For the API (set generate_table = False), the
    result is the recommendation report itself.
    NOTE(review): the table/report switch is not visible in this method; it is
    presumably handled by ``which_shapes`` based on the request - confirm.

    Parameters
    ----------
    **kwargs
        Fields accepted by ``RequestRecommend``. Known options include:

        model_ocid : str
            OCID of the model to recommend feasible compute shapes.
        generate_table : bool
            Selects the table (CLI) or report (API) output described above.

    Returns
    -------
    Table (generate_table = True)
        A table format for the recommendation report with compatible deployment shapes
        or troubleshooting info citing the largest shapes if no shape is suitable.

    ShapeRecommendationReport (generate_table = False)
        A recommendation report with compatible deployment shapes, or troubleshooting info
        citing the largest shapes if no shape is suitable.

    Raises
    ------
    AquaValueError
        If the keyword arguments fail ``RequestRecommend`` validation. The
        original pydantic ``ValidationError`` is attached as the cause.
        Per the previous documentation it is also raised when the model type is
        unsupported by the tool (no recommendation report generated); that
        path is not visible here - presumably inside ``which_shapes``.
    """
    try:
        request = RequestRecommend(**kwargs)
    except ValidationError as e:
        custom_error = build_pydantic_error_message(e)
        # Chain the pydantic error explicitly so the root cause stays on the
        # traceback, matching how other handlers in this module re-raise.
        raise AquaValueError(
            f"Failed to request shape recommendation due to invalid input parameters: {custom_error}"
        ) from e

    return AquaShapeRecommend().which_shapes(request)
| 1312 | + |
1252 | 1313 | @telemetry(entry_point="plugin=deployment&action=list_shapes", name="aqua") |
1253 | 1314 | @cached(cache=TTLCache(maxsize=1, ttl=timedelta(minutes=5), timer=datetime.now)) |
1254 | 1315 | def list_shapes(self, **kwargs) -> List[ComputeShapeSummary]: |
|
0 commit comments