@@ -2306,60 +2306,64 @@ def test_validate_multimodel_deployment_feasibility_positive_single(
23062306 "test_data/deployment/aqua_summary_multi_model_single.json" ,
23072307 )
23082308
2309-
23102309class TestBaseModelSpec :
23112310 VALID_WEIGHT = LoraModuleSpec (
23122311 model_name = "ft_model" ,
23132312 model_path = "oci://test_bucket@test_namespace/" ,
23142313 )
23152314
23162315 @pytest .mark .parametrize (
2317- "model_path, ft_weights, expect_warning" ,
2316+ "model_path, ft_weights, expect_warning, expect_error " ,
23182317 [
2319- ("oci://test_location_3" , [VALID_WEIGHT , VALID_WEIGHT ], True ),
2320- ("oci://test_location_3" , [], False ),
2321- ("not-a-valid-uri" , [VALID_WEIGHT ], False ),
2318+ ("oci://test_location_3" , [VALID_WEIGHT , VALID_WEIGHT ], True , False ),
2319+ ("oci://test_location_3" , [], False , False ),
2320+ ("not-a-valid-uri" , [VALID_WEIGHT ], False , True ),
23222321 ],
23232322 )
2324- def test_invalid_base_model_spec (
2323+ def test_invalid_from_aqua_multi_model_ref (
23252324 self ,
23262325 model_path ,
23272326 ft_weights ,
23282327 expect_warning ,
2328+ expect_error ,
23292329 caplog ,
23302330 ):
23312331 logger = logging .getLogger ("ads.aqua.modeldeployment.model_group_config" )
23322332 logger .propagate = True
23332333
23342334 caplog .set_level (logging .WARNING , logger = logger .name )
23352335
2336- with pytest .raises (ValidationError ) as excinfo :
2337- BaseModelSpec (
2338- model_id = "test_model_id_3" ,
2339- model_name = "test_model_3" ,
2340- model_task = "code_synthesis" ,
2341- model_path = model_path ,
2342- fine_tune_weights = ft_weights ,
2343- )
2336+ model_ref = AquaMultiModelRef (
2337+ artifact_location = model_path ,
2338+ model_task = "code_synthesis" ,
2339+ model_name = "test_model_3" ,
2340+ model_id = "test_model_id_3" ,
2341+ fine_tune_weights = ft_weights ,
2342+ env_var = {},
2343+ gpu_count = 1 ,
2344+ )
2345+
2346+ model_params = "--dummy-param"
2347+
2348+ if expect_error :
2349+ with pytest .raises (ValidationError ) as excinfo :
2350+ BaseModelSpec .from_aqua_multi_model_ref (model_ref , model_params )
2351+ errs = excinfo .value .errors ()
2352+ if not model_path .startswith ("oci://" ):
2353+ model_path_errors = [e for e in errs if e ["loc" ] == ("model_path" ,)]
2354+ assert model_path_errors , f"expected a model_path error, got: { errs !r} "
2355+ assert (
2356+ "the base model path is not available in the model artifact."
2357+ in model_path_errors [0 ]["msg" ].lower ()
2358+ )
2359+ else :
2360+ BaseModelSpec .from_aqua_multi_model_ref (model_ref , model_params )
23442361
23452362 messages = [rec .getMessage ().lower () for rec in caplog .records ]
2346-
23472363 if expect_warning :
23482364 assert any (
23492365 "duplicate lora modules detected" in m for m in messages
23502366 ), f"Expected warning, got none. Captured messages: { messages } "
23512367 else :
23522368 assert not messages , f"Did not expect any warnings, but got: { messages } "
23532369
2354- # inspecting if errors are thrown
2355- errs = excinfo .value .errors ()
2356- if not model_path .startswith ("oci://" ):
2357- model_path_errors = [e for e in errs if e ["loc" ] == ("model_path" ,)]
2358- assert model_path_errors , f"expected a model_path error, got: { errs !r} "
2359- assert (
2360- "the base model path is not available in the model artifact."
2361- in model_path_errors [0 ]["msg" ].lower ()
2362- )
2363- else :
2364- # e.g. for the duplicate‐weights case you might check for a different loc/msg
2365- pass
0 commit comments