Skip to content

Commit 19de240

Browse files
authored
Merge branch 'main' into ODSC-76209/GPU-Shape-Recommendation
2 parents 015aa56 + 4bb5fb6 commit 19de240

File tree

17 files changed

+1122
-58
lines changed

17 files changed

+1122
-58
lines changed

ads/aqua/constants.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -56,6 +56,9 @@
5656
MODEL_BY_REFERENCE_OSS_PATH_KEY = "artifact_location"
5757

5858
AQUA_CHAT_TEMPLATE_METADATA_KEY = "chat_template"
59+
UNKNOWN_ENUM_VALUE = "UNKNOWN_ENUM_VALUE"
60+
MODEL_GROUP = "MODEL_GROUP"
61+
SINGLE_MODEL_FLEX = "SINGLE_MODEL_FLEX"
5962

6063
CONSOLE_LINK_RESOURCE_TYPE_MAPPING = {
6164
"datasciencemodel": "models",

ads/aqua/modeldeployment/deployment.py

Lines changed: 19 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -45,8 +45,11 @@
4545
AQUA_MODEL_TYPE_SERVICE,
4646
AQUA_MULTI_MODEL_CONFIG,
4747
MODEL_BY_REFERENCE_OSS_PATH_KEY,
48+
MODEL_GROUP,
4849
MODEL_NAME_DELIMITER,
50+
SINGLE_MODEL_FLEX,
4951
UNKNOWN_DICT,
52+
UNKNOWN_ENUM_VALUE,
5053
)
5154
from ads.aqua.data import AquaResourceIdentifier
5255
from ads.aqua.model import AquaModelApp
@@ -873,21 +876,26 @@ def list(self, **kwargs) -> List["AquaDeployment"]:
873876

874877
if oci_aqua:
875878
# skipping the AQUA model deployments that are created from model group
876-
# TODO: remove this checker after AQUA deployment is integrated with model group
877-
aqua_model_id = model_deployment.freeform_tags.get(
878-
Tags.AQUA_MODEL_ID_TAG, UNKNOWN
879-
)
880879
if (
881-
"datasciencemodelgroup" in aqua_model_id
882-
or model_deployment.model_deployment_configuration_details.deployment_type
883-
== "UNKNOWN_ENUM_VALUE"
880+
model_deployment.model_deployment_configuration_details.deployment_type
881+
in [UNKNOWN_ENUM_VALUE, MODEL_GROUP, SINGLE_MODEL_FLEX]
884882
):
885883
continue
886-
results.append(
887-
AquaDeployment.from_oci_model_deployment(
888-
model_deployment, self.region
884+
try:
885+
results.append(
886+
AquaDeployment.from_oci_model_deployment(
887+
model_deployment, self.region
888+
)
889889
)
890-
)
890+
except Exception as e:
891+
logger.error(
892+
f"There was an issue processing the list of model deployments . Error: {str(e)}",
893+
exc_info=True,
894+
)
895+
raise AquaRuntimeError(
896+
f"There was an issue processing the list of model deployments . Error: {str(e)}"
897+
) from e
898+
891899
# log telemetry if MD is in active or failed state
892900
deployment_id = model_deployment.id
893901
state = model_deployment.lifecycle_state.upper()

ads/aqua/shaperecommend/constants.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -94,6 +94,9 @@
9494
DEFAULT_WEIGHT_SIZE = "bfloat16"
9595
DEFAULT_MAX_SEQ_LEN = 4096
9696

97+
DEFAULT_WEIGHT_SIZE = "float32"
98+
99+
97100
BITS_AND_BYTES_8BIT = "8bit"
98101
BITS_AND_BYTES_4BIT = "4bit"
99102

ads/aqua/shaperecommend/recommend.py

Lines changed: 11 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -111,7 +111,17 @@ def which_shapes(
111111
shape_recommendation_report = self._summarize_shapes_for_seq_lens(
112112
llm_config, shapes, model_name
113113
)
114+
115+
data = self._get_model_config(ds_model)
114116

117+
llm_config = LLMConfig.from_raw_config(data)
118+
119+
model_name = ds_model.display_name if ds_model.display_name else ""
120+
121+
shape_recommendation_report = self._summarize_shapes_for_seq_lens(
122+
llm_config, shapes, model_name
123+
)
124+
115125
if request.generate_table and shape_recommendation_report.recommendations:
116126
shape_recommendation_report = self._rich_diff_table(
117127
shape_recommendation_report
@@ -257,6 +267,7 @@ def _rich_diff_table(shape_report: ShapeRecommendationReport) -> Table:
257267
else:
258268
total_memory = f"CPU: {str(shape.memory_in_gbs)}"
259269

270+
260271
if model:
261272
model_size = str(model.total_model_gb)
262273
else:

ads/aqua/shaperecommend/shape_report.py

Lines changed: 10 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,7 @@
33
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
44

55
import json
6+
67
from typing import List, Optional
78

89
from pydantic import BaseModel, Field
@@ -17,6 +18,7 @@
1718
VLLM_ENV_KEY,
1819
VLLM_PARAMS_KEY,
1920
)
21+
from ads.aqua.shaperecommend.constants import QUANT_MAPPING
2022
from ads.aqua.shaperecommend.estimator import MemoryEstimator
2123
from ads.config import COMPARTMENT_OCID
2224

@@ -56,6 +58,8 @@ class DeploymentParams(BaseModel): # noqa: N801
5658
None, description="Type of quantization (e.g. 4bit)."
5759
)
5860
max_model_len: Optional[int] = Field(None, description="Maximum length of input sequence.")
61+
max_model_len: int = Field(..., description="Maximum length of input sequence.")
62+
5963
params: str = Field(
6064
..., description="Runtime parameters for deployment with vLLM, etc."
6165
)
@@ -88,6 +92,12 @@ class ModelConfig(BaseModel):
8892

8993
recommendation: Optional[str] = Field("", description="GPU recommendation for the model.")
9094

95+
model_details: ModelDetail = Field(..., description="Details about the model.")
96+
deployment_params: DeploymentParams = Field(
97+
..., description="Parameters for deployment."
98+
)
99+
recommendation: str = Field(..., description="GPU recommendation for the model.")
100+
91101
class Config:
92102
protected_namespaces = ()
93103

@@ -246,7 +256,6 @@ class ShapeRecommendationReport(BaseModel):
246256
description="Details for troubleshooting if no shapes fit the current model.",
247257
)
248258

249-
250259
@classmethod
251260
def from_deployment_config(cls, deployment_config: AquaDeploymentConfig, model_name: str, valid_shapes: List[ComputeShapeSummary]) -> "ShapeRecommendationReport":
252261
"""

ads/jobs/builders/infrastructure/dsc_job.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1751,6 +1751,7 @@ def is_multi_node_job(runtime):
17511751
return (
17521752
MULTI_NODE_JOB_SUPPORT
17531753
and isinstance(runtime, MultiNodeRuntime)
1754+
and runtime.replica
17541755
and runtime.replica > 1
17551756
)
17561757

ads/jobs/builders/infrastructure/dsc_job_runtime.py

Lines changed: 9 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -365,6 +365,11 @@ def _get_node_group(self, dsc_job):
365365
dsc_job,
366366
"job_node_configuration_details.job_node_group_configuration_details_list",
367367
)
368+
if node_groups is None:
369+
node_groups = get_value(
370+
dsc_job,
371+
"job_node_configuration_details.jobNodeGroupConfigurationDetailsList",
372+
)
368373
if node_groups and len(node_groups) == 1:
369374
return node_groups[0]
370375
return None
@@ -373,6 +378,7 @@ def _get_replica(self, dsc_job, envs):
373378
node_group = self._get_node_group(dsc_job)
374379
if node_group:
375380
replica = get_value(node_group, "replicas")
381+
envs.pop(self.CONST_NODE_COUNT, None)
376382
elif not envs:
377383
replica = None
378384
elif self.CONST_WORKER_COUNT in envs:
@@ -399,7 +405,9 @@ def _extract_envs(self, dsc_job):
399405
env_attr = "job_configuration_details.environment_variables"
400406
node_group = self._get_node_group(dsc_job)
401407
if node_group:
402-
envs = get_value(node_group, env_attr)
408+
envs = get_value(node_group, env_attr) or get_value(
409+
node_group, "jobConfigurationDetails.environment_variables"
410+
)
403411
else:
404412
envs = get_value(dsc_job, env_attr)
405413
if envs:

ads/pipeline/ads_pipeline.py

Lines changed: 13 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -1728,15 +1728,19 @@ def __step_details(self, pipeline_details: Dict) -> list:
17281728

17291729
def __step_infrastructure_configuration_details(self, step) -> dict:
17301730
step_infrastructure_configuration_details = {}
1731-
step_infrastructure_configuration_details[
1732-
"blockStorageSizeInGBs"
1733-
] = step.infrastructure.block_storage_size
1734-
step_infrastructure_configuration_details[
1735-
"shapeName"
1736-
] = step.infrastructure.shape_name
1737-
step_infrastructure_configuration_details[
1738-
"shapeConfigDetails"
1739-
] = step.infrastructure.shape_config_details
1731+
step_infrastructure_configuration_details["blockStorageSizeInGBs"] = (
1732+
step.infrastructure.block_storage_size
1733+
)
1734+
step_infrastructure_configuration_details["shapeName"] = (
1735+
step.infrastructure.shape_name
1736+
)
1737+
step_infrastructure_configuration_details["shapeConfigDetails"] = (
1738+
step.infrastructure.shape_config_details
1739+
)
1740+
if getattr(step.infrastructure, "subnet_id", ""):
1741+
step_infrastructure_configuration_details["subnetId"] = (
1742+
step.infrastructure.subnet_id
1743+
)
17401744
return step_infrastructure_configuration_details
17411745

17421746
def __step_configuration_details(self, pipeline_details: Dict, step) -> dict:

tests/unitary/default_setup/jobs/test_jobs_pytorch_ddp.py

Lines changed: 22 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,9 @@
1010
from unittest import mock
1111

1212
from ads.jobs import DataScienceJob, DataScienceJobRun, PyTorchDistributedRuntime
13+
from ads.jobs.builders.infrastructure.dsc_job_runtime import (
14+
MULTI_NODE_JOB_SUPPORT,
15+
)
1316
from ads.jobs.builders.infrastructure.dsc_job_runtime import (
1417
PyTorchDistributedRuntimeHandler as Handler,
1518
)
@@ -133,23 +136,26 @@ def test_create_job_runs(self, patched_run, *args):
133136
self.assertIsInstance(main_run, DataScienceJobRun)
134137
self.assertEqual(main_run.id, test_ocid)
135138
kwarg_list = [call_args.kwargs for call_args in patched_run.call_args_list]
136-
self.assertEqual(
137-
kwarg_list,
138-
[
139-
{
140-
"display_name": "None-0",
141-
"environment_variables": {"NODE_RANK": "0", "NODE_COUNT": "2"},
142-
},
143-
{
144-
"display_name": "None-1",
145-
"environment_variables": {
146-
"NODE_RANK": "1",
147-
"NODE_COUNT": "2",
148-
"MAIN_JOB_RUN_OCID": test_ocid,
139+
if MULTI_NODE_JOB_SUPPORT:
140+
self.assertEqual(kwarg_list, [{}])
141+
else:
142+
self.assertEqual(
143+
kwarg_list,
144+
[
145+
{
146+
"display_name": "None-0",
147+
"environment_variables": {"NODE_RANK": "0", "NODE_COUNT": "2"},
149148
},
150-
},
151-
],
152-
)
149+
{
150+
"display_name": "None-1",
151+
"environment_variables": {
152+
"NODE_RANK": "1",
153+
"NODE_COUNT": "2",
154+
"MAIN_JOB_RUN_OCID": test_ocid,
155+
},
156+
},
157+
],
158+
)
153159

154160
@mock.patch.dict(
155161
os.environ, {utils.CONST_ENV_INPUT_MAPPINGS: json.dumps({INPUT_SRC: INPUT_DST})}
Lines changed: 27 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,27 @@
1+
{
2+
"architectures": [
3+
"MistralForCausalLM"
4+
],
5+
"attention_dropout": 0.0,
6+
"bos_token_id": 1,
7+
"eos_token_id": 2,
8+
"pad_token_id": 11,
9+
"head_dim": 128,
10+
"hidden_act": "silu",
11+
"hidden_size": 5120,
12+
"initializer_range": 0.02,
13+
"intermediate_size": 32768,
14+
"max_position_embeddings": 131072,
15+
"model_type": "mistral",
16+
"num_attention_heads": 32,
17+
"num_hidden_layers": 40,
18+
"num_key_value_heads": 8,
19+
"rms_norm_eps": 1e-05,
20+
"rope_theta": 1000000000.0,
21+
"sliding_window": null,
22+
"tie_word_embeddings": false,
23+
"torch_dtype": "bfloat16",
24+
"transformers_version": "4.53.1",
25+
"use_cache": true,
26+
"vocab_size": 131072
27+
}

0 commit comments

Comments
 (0)