Skip to content

Commit e94b6f1

Browse files
committed
fixed unit tests
1 parent a17b035 commit e94b6f1

File tree

7 files changed

+825
-132
lines changed

7 files changed

+825
-132
lines changed

ads/aqua/common/utils.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1253,24 +1253,24 @@ def load_gpu_shapes_index(
12531253
file_name = "gpu_shapes_index.json"
12541254

12551255
# Try remote load
1256-
# remote_data: Dict[str, Any] = {}
1257-
# if CONDA_BUCKET_NS:
1258-
# try:
1259-
# auth = auth or authutil.default_signer()
1260-
# storage_path = (
1261-
# f"oci://{CONDA_BUCKET_NAME}@{CONDA_BUCKET_NS}/service_pack/{file_name}"
1262-
# )
1263-
# logger.debug(
1264-
# "Loading GPU shapes index from Object Storage: %s", storage_path
1265-
# )
1266-
# with fsspec.open(storage_path, mode="r", **auth) as f:
1267-
# remote_data = json.load(f)
1268-
# logger.debug(
1269-
# "Loaded %d shapes from Object Storage",
1270-
# len(remote_data.get("shapes", {})),
1271-
# )
1272-
# except Exception as ex:
1273-
# logger.debug("Remote load failed (%s); falling back to local", ex)
1256+
remote_data: Dict[str, Any] = {}
1257+
if CONDA_BUCKET_NS:
1258+
try:
1259+
auth = auth or authutil.default_signer()
1260+
storage_path = (
1261+
f"oci://{CONDA_BUCKET_NAME}@{CONDA_BUCKET_NS}/service_pack/{file_name}"
1262+
)
1263+
logger.debug(
1264+
"Loading GPU shapes index from Object Storage: %s", storage_path
1265+
)
1266+
with fsspec.open(storage_path, mode="r", **auth) as f:
1267+
remote_data = json.load(f)
1268+
logger.debug(
1269+
"Loaded %d shapes from Object Storage",
1270+
len(remote_data.get("shapes", {})),
1271+
)
1272+
except Exception as ex:
1273+
logger.debug("Remote load failed (%s); falling back to local", ex)
12741274

12751275
# Load local copy
12761276
local_data: Dict[str, Any] = {}

ads/aqua/shaperecommend/recommend.py

Lines changed: 44 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def which_gpu(self, **kwargs) -> ShapeRecommendationReport:
6161
6262
Parameters
6363
----------
64-
model_ocid : str
64+
ocid : str
6565
OCID of the model to recommend feasible compute shapes.
6666
6767
Returns
@@ -77,19 +77,23 @@ def which_gpu(self, **kwargs) -> ShapeRecommendationReport:
7777
"""
7878
try:
7979
request = RequestRecommend(**kwargs)
80-
data, model_name = self.get_model_config(request.model_ocid)
80+
ds_model = self.validate_model_ocid(request.model_ocid)
81+
data = self.get_model_config(ds_model)
8182

8283
llm_config = LLMConfig.from_raw_config(data)
8384

8485
available_shapes = self.valid_compute_shapes()
86+
87+
model_name = ds_model.display_name if ds_model.display_name else ""
88+
8589
recommendations = self.summarize_shapes_for_seq_lens(
8690
llm_config, available_shapes, model_name
8791
)
8892

8993
# custom error to catch model incompatibility issues
9094
except AquaRecommendationError as error:
9195
return ShapeRecommendationReport(
92-
recommendations=[], troubleshoot=str(error)
96+
recommendations=[], troubleshoot=str(error)
9397
)
9498

9599
except ValidationError as ex:
@@ -115,10 +119,16 @@ def rich_diff_table(shape_report: ShapeRecommendationReport) -> Table:
115119
Returns:
116120
Table: A rich Table displaying model deployment recommendations.
117121
"""
118-
logger.debug("Starting to generate rich diff table from ShapeRecommendationReport.")
122+
logger.debug(
123+
"Starting to generate rich diff table from ShapeRecommendationReport."
124+
)
119125

120-
name = shape_report.model_name
121-
header = f"Model Deployment Recommendations: {name}" if name else "Model Deployment Recommendations"
126+
name = shape_report.display_name
127+
header = (
128+
f"Model Deployment Recommendations: {name}"
129+
if name
130+
else "Model Deployment Recommendations"
131+
)
122132
logger.debug(f"Table header set to: {header!r}")
123133

124134
if shape_report.troubleshoot:
@@ -167,13 +177,12 @@ def rich_diff_table(shape_report: ShapeRecommendationReport) -> Table:
167177
str(model.total_model_gb),
168178
deploy.quantization,
169179
str(deploy.max_model_len),
170-
full_recommendation
180+
full_recommendation,
171181
)
172182

173183
logger.debug("Completed populating table with recommendation rows.")
174184
return table
175185

176-
177186
def shapes(self, **kwargs) -> Table:
178187
"""
179188
For the CLI, generates the table (in rich diff) with valid GPU deployment shapes
@@ -203,12 +212,31 @@ def shapes(self, **kwargs) -> Table:
203212
if shape_recommend_report.troubleshoot:
204213
raise AquaValueError(shape_recommend_report.troubleshoot)
205214
else:
206-
raise AquaValueError("Unable to generate recommendations from model. Please ensure model is registered and is a decoder-only text-generation model.")
215+
raise AquaValueError(
216+
"Unable to generate recommendations from model. Please ensure model is registered and is a decoder-only text-generation model."
217+
)
207218

208219
return self.rich_diff_table(shape_recommend_report)
209220

210221
@staticmethod
211-
def get_model_config(ocid: str):
222+
def validate_model_ocid(ocid: str) -> DataScienceModel:
223+
"""
224+
Ensures the OCID passed is valid for referencing a DataScienceModel resource.
225+
"""
226+
resource_type = get_resource_type(ocid)
227+
228+
if resource_type != "datasciencemodel":
229+
raise AquaValueError(
230+
f"The provided OCID '{ocid}' is not a valid Oracle Cloud Data Science Model OCID. "
231+
"Please provide an OCID corresponding to a Data Science model resource. "
232+
"Tip: Data Science model OCIDs typically start with 'ocid1.datasciencemodel...'."
233+
)
234+
235+
model = DataScienceModel.from_id(ocid)
236+
return model
237+
238+
@staticmethod
239+
def get_model_config(model: DataScienceModel):
212240
"""
213241
Loads the configuration for a given Oracle Cloud Data Science model.
214242
@@ -218,8 +246,8 @@ def get_model_config(ocid: str):
218246
219247
Parameters
220248
----------
221-
ocid : str
222-
The OCID of the Data Science model.
249+
model : DataScienceModel
250+
The DataScienceModel representation of the model used in recommendations
223251
224252
Returns
225253
-------
@@ -235,18 +263,6 @@ def get_model_config(ocid: str):
235263
AquaRecommendationError
236264
If the model OCID provided is not supported (only text-generation decoder models in safetensor format supported).
237265
"""
238-
resource_type = get_resource_type(ocid)
239-
240-
if resource_type != "datasciencemodel":
241-
raise AquaValueError(
242-
f"The provided OCID '{ocid}' is not a valid Oracle Cloud Data Science Model OCID. "
243-
"Please provide an OCID corresponding to a Data Science model resource. "
244-
"Tip: Data Science model OCIDs typically start with 'ocid1.datasciencemodel...'."
245-
)
246-
247-
model = DataScienceModel.from_id(ocid)
248-
249-
model_name = model.display_name
250266

251267
model_task = model.freeform_tags.get("task", "").lower()
252268
model_format = model.freeform_tags.get("model_format", "").lower()
@@ -283,7 +299,7 @@ def get_model_config(ocid: str):
283299
"Please ensure your model follows the Hugging Face format and includes a 'config.json' with the necessary architecture parameters."
284300
) from e
285301

286-
return data, model_name
302+
return data
287303

288304
@staticmethod
289305
def valid_compute_shapes() -> List["ComputeShapeSummary"]:
@@ -444,5 +460,7 @@ def summarize_shapes_for_seq_lens(
444460
)
445461

446462
return ShapeRecommendationReport(
447-
model_name=name, recommendations=recommendations, troubleshoot=troubleshoot_msg
463+
display_name=name,
464+
recommendations=recommendations,
465+
troubleshoot=troubleshoot_msg,
448466
)

ads/aqua/shaperecommend/shape_report.py

Lines changed: 45 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,12 @@ class RequestRecommend(BaseModel):
1616
A request to recommend compute shapes and parameters for a given model.
1717
"""
1818

19-
model_ocid: str = Field(..., description="The OCID of the model to recommend feasible compute shapes.")
19+
model_ocid: str = Field(
20+
..., description="The OCID of the model to recommend feasible compute shapes."
21+
)
22+
23+
class Config:
24+
protected_namespaces = ()
2025

2126

2227
class DeploymentParams(BaseModel): # noqa: N801
@@ -42,6 +47,9 @@ class ModelDetail(BaseModel):
4247
..., description="Total size of model and cache in GB."
4348
)
4449

50+
class Config:
51+
protected_namespaces = ()
52+
4553

4654
class ModelConfig(BaseModel):
4755
"""
@@ -54,8 +62,13 @@ class ModelConfig(BaseModel):
5462
)
5563
recommendation: str = Field(..., description="GPU recommendation for the model.")
5664

65+
class Config:
66+
protected_namespaces = ()
67+
5768
@classmethod
58-
def constuct_model_config(cls, estimator: MemoryEstimator, allowed_gpu_memory: float) -> "ModelConfig":
69+
def constuct_model_config(
70+
cls, estimator: MemoryEstimator, allowed_gpu_memory: float
71+
) -> "ModelConfig":
5972
"""
6073
Assembles a complete ModelConfig, including model details, deployment parameters (vLLM), and recommendations.
6174
@@ -78,32 +91,33 @@ def constuct_model_config(cls, estimator: MemoryEstimator, allowed_gpu_memory: f
7891
"""
7992
deployment_params = DeploymentParams(
8093
quantization=getattr(estimator.llm_config, "quantization", None),
81-
max_model_len=getattr(estimator, "seq_len", None)
94+
max_model_len=getattr(estimator, "seq_len", None),
8295
)
8396
model_detail = ModelDetail(
8497
model_size_gb=round(getattr(estimator, "model_memory", 0.0), 2),
8598
kv_cache_size_gb=round(getattr(estimator, "kv_cache_memory", 0.0), 2),
86-
total_model_gb=round(getattr(estimator, "total_memory", 0.0), 2)
99+
total_model_gb=round(getattr(estimator, "total_memory", 0.0), 2),
87100
)
88101
return ModelConfig(
89102
model_details=model_detail,
90103
deployment_params=deployment_params,
91-
recommendation= estimator.limiting_factor(allowed_gpu_memory)
104+
recommendation=estimator.limiting_factor(allowed_gpu_memory),
92105
)
93106

94107

95108
class ShapeReport(BaseModel):
96109
"""
97110
The feasible deployment configurations for the model per shape.
98111
"""
99-
shape_details: 'ComputeShapeSummary' = Field(
112+
113+
shape_details: "ComputeShapeSummary" = Field(
100114
..., description="Details about the compute shape (ex. VM.GPU.A10.2)."
101115
)
102-
configurations: List['ModelConfig'] = Field(
116+
configurations: List["ModelConfig"] = Field(
103117
default_factory=list, description="List of model configurations."
104118
)
105119

106-
def is_dominated(self, others: List['ShapeReport']) -> bool:
120+
def is_dominated(self, others: List["ShapeReport"]) -> bool:
107121
"""
108122
Determines whether this shape is dominated by any other shape in a Pareto sense.
109123
@@ -128,31 +142,35 @@ def is_dominated(self, others: List['ShapeReport']) -> bool:
128142

129143
cand_cost = self.shape_details.gpu_specs.ranking.cost
130144
cand_perf = self.shape_details.gpu_specs.ranking.performance
131-
cand_quant = QUANT_MAPPING.get(self.configurations[0].deployment_params.quantization, 0)
145+
cand_quant = QUANT_MAPPING.get(
146+
self.configurations[0].deployment_params.quantization, 0
147+
)
132148
cand_maxlen = self.configurations[0].deployment_params.max_model_len
133149

134150
for other in others:
135151
other_cost = other.shape_details.gpu_specs.ranking.cost
136152
other_perf = other.shape_details.gpu_specs.ranking.performance
137-
other_quant = QUANT_MAPPING.get(other.configurations[0].deployment_params.quantization, 0)
153+
other_quant = QUANT_MAPPING.get(
154+
other.configurations[0].deployment_params.quantization, 0
155+
)
138156
other_maxlen = other.configurations[0].deployment_params.max_model_len
139157
if (
140-
other_cost <= cand_cost and
141-
other_perf >= cand_perf and
142-
other_quant >= cand_quant and
143-
other_maxlen >= cand_maxlen and
144-
(
145-
other_cost < cand_cost or
146-
other_perf > cand_perf or
147-
other_quant > cand_quant or
148-
other_maxlen > cand_maxlen
158+
other_cost <= cand_cost
159+
and other_perf >= cand_perf
160+
and other_quant >= cand_quant
161+
and other_maxlen >= cand_maxlen
162+
and (
163+
other_cost < cand_cost
164+
or other_perf > cand_perf
165+
or other_quant > cand_quant
166+
or other_maxlen > cand_maxlen
149167
)
150168
):
151169
return True
152170
return False
153171

154172
@classmethod
155-
def pareto_front(cls, shapes: List['ShapeReport']) -> List['ShapeReport']:
173+
def pareto_front(cls, shapes: List["ShapeReport"]) -> List["ShapeReport"]:
156174
"""
157175
Filters a list of shapes/configurations to those on the Pareto frontier.
158176
@@ -171,7 +189,11 @@ def pareto_front(cls, shapes: List['ShapeReport']) -> List['ShapeReport']:
171189
The returned set contains non-dominated deployments for maximizing
172190
performance, quantization, and model length, while minimizing cost.
173191
"""
174-
return [shape for shape in shapes if not shape.is_dominated([s for s in shapes if s != shape])]
192+
return [
193+
shape
194+
for shape in shapes
195+
if not shape.is_dominated([s for s in shapes if s != shape])
196+
]
175197

176198

177199
class ShapeRecommendationReport(BaseModel):
@@ -184,7 +206,8 @@ class ShapeRecommendationReport(BaseModel):
184206
troubleshoot (Optional[TroubleshootShapeSummary]): Troubleshooting information
185207
if no valid deployment shapes are available.
186208
"""
187-
model_name: Optional[str] = Field(
209+
210+
display_name: Optional[str] = Field(
188211
"", description="Name of the model used for recommendations."
189212
)
190213
recommendations: List[ShapeReport] = Field(

0 commit comments

Comments (0)