Skip to content

Commit e224945

Browse files
committed
patch to fix missing fine_tune_weights in MULTI_MODEL_CONFIG
1 parent 5511fc8 commit e224945

File tree

2 files changed: +38 additions, −39 deletions

ads/aqua/modeldeployment/model_group_config.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,8 @@ class BaseModelSpec(BaseModel):
         Additional vLLM launch parameters for this model (e.g. parallelism, max context).
     model_task : str, optional
         Model task type (e.g., text-generation, image-to-text).
-    fine_tune_weights : List[FineTunedModelSpec], optional
-        List of associated fine-tuned models.
+    fine_tune_weights : List[List[LoraModuleSpec]], optional
+        List of associated LoRA modules for fine-tuned models.
     """

     model_path: str = Field(..., description="Path to the base model.")
@@ -71,17 +71,13 @@ def clean_model_path(cls, artifact_path_prefix: str) -> str:
                 "The base model path is not available in the model artifact."
             )

-    @field_validator("fine_tune_weights")
     @classmethod
-    def set_fine_tuned_weights(cls, fine_tuned_weights: List[LoraModuleSpec]):
-        """Removes duplicate LoRA Modules (duplicate model_names in fine_tuned_weights)"""
+    def dedup_lora_modules(cls, fine_tune_weights: List[LoraModuleSpec]):
+        """Removes duplicate LoRA Modules (duplicate model_names in fine_tune_weights)"""
         seen_modules = set()
         unique_modules: List[LoraModuleSpec] = []

-        if not fine_tuned_weights:
-            return None
-
-        for lora_module in fine_tuned_weights:
+        for lora_module in fine_tune_weights or []:
             if lora_module.model_name not in seen_modules:
                 seen_modules.add(lora_module.model_name)
                 unique_modules.append(lora_module)
@@ -101,7 +97,7 @@ def from_aqua_multi_model_ref(
             model_path=model.artifact_location,
             params=model_params,
             model_task=model.model_task,
-            fine_tuned_weights=model.fine_tune_weights,
+            fine_tune_weights=cls.dedup_lora_modules(model.fine_tune_weights),
         )

@@ -112,7 +108,7 @@ class ModelGroupConfig(Serializable):
    Attributes
    ----------
    models : List[BaseModelConfig]
-        List of base models (with optional fine-tuned weights) to be served.
+        List of base models (with optional fine-tune weights) to be served.
    """

    models: List[BaseModelSpec] = Field(
@@ -228,5 +224,4 @@ def from_create_model_deployment_details(
                 "Each base model must have a unique `model_name`. "
                 "Please remove or rename the duplicate model and register the model group again."
             )
-
        return cls(models=models)

tests/unitary/with_extras/aqua/test_deployment.py

Lines changed: 31 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -2306,60 +2306,64 @@ def test_validate_multimodel_deployment_feasibility_positive_single(
             "test_data/deployment/aqua_summary_multi_model_single.json",
         )

-
 class TestBaseModelSpec:
     VALID_WEIGHT = LoraModuleSpec(
         model_name="ft_model",
         model_path="oci://test_bucket@test_namespace/",
     )

     @pytest.mark.parametrize(
-        "model_path, ft_weights, expect_warning",
+        "model_path, ft_weights, expect_warning, expect_error",
         [
-            ("oci://test_location_3", [VALID_WEIGHT, VALID_WEIGHT], True),
-            ("oci://test_location_3", [], False),
-            ("not-a-valid-uri", [VALID_WEIGHT], False),
+            ("oci://test_location_3", [VALID_WEIGHT, VALID_WEIGHT], True, False),
+            ("oci://test_location_3", [], False, False),
+            ("not-a-valid-uri", [VALID_WEIGHT], False, True),
         ],
     )
-    def test_invalid_base_model_spec(
+    def test_invalid_from_aqua_multi_model_ref(
        self,
        model_path,
        ft_weights,
        expect_warning,
+        expect_error,
        caplog,
     ):
        logger = logging.getLogger("ads.aqua.modeldeployment.model_group_config")
        logger.propagate = True

        caplog.set_level(logging.WARNING, logger=logger.name)

-        with pytest.raises(ValidationError) as excinfo:
-            BaseModelSpec(
-                model_id="test_model_id_3",
-                model_name="test_model_3",
-                model_task="code_synthesis",
-                model_path=model_path,
-                fine_tune_weights=ft_weights,
-            )
+        model_ref = AquaMultiModelRef(
+            artifact_location=model_path,
+            model_task="code_synthesis",
+            model_name="test_model_3",
+            model_id="test_model_id_3",
+            fine_tune_weights=ft_weights,
+            env_var={},
+            gpu_count=1,
+        )
+
+        model_params = "--dummy-param"
+
+        if expect_error:
+            with pytest.raises(ValidationError) as excinfo:
+                BaseModelSpec.from_aqua_multi_model_ref(model_ref, model_params)
+            errs = excinfo.value.errors()
+            if not model_path.startswith("oci://"):
+                model_path_errors = [e for e in errs if e["loc"] == ("model_path",)]
+                assert model_path_errors, f"expected a model_path error, got: {errs!r}"
+                assert (
+                    "the base model path is not available in the model artifact."
+                    in model_path_errors[0]["msg"].lower()
+                )
+        else:
+            BaseModelSpec.from_aqua_multi_model_ref(model_ref, model_params)

        messages = [rec.getMessage().lower() for rec in caplog.records]
-
        if expect_warning:
            assert any(
                "duplicate lora modules detected" in m for m in messages
            ), f"Expected warning, got none. Captured messages: {messages}"
        else:
            assert not messages, f"Did not expect any warnings, but got: {messages}"

-
-        # inspecting if errors are thrown
-        errs = excinfo.value.errors()
-        if not model_path.startswith("oci://"):
-            model_path_errors = [e for e in errs if e["loc"] == ("model_path",)]
-            assert model_path_errors, f"expected a model_path error, got: {errs!r}"
-            assert (
-                "the base model path is not available in the model artifact."
-                in model_path_errors[0]["msg"].lower()
-            )
-        else:
-            # e.g. for the duplicate-weights case you might check for a different loc/msg
-            pass

0 commit comments

Comments (0)