@@ -20,6 +20,7 @@
     ModelConfig,
     ModelDetail,
     RequestRecommend,
+    ShapeRecommendationReport,
     ShapeReport,
 )
 from ads.model.model_metadata import ModelCustomMetadata, ModelProvenanceMetadata
@@ -233,9 +234,10 @@ def __init__(self):
         local_shapes = local_data.get("shapes", {})
         self.shapes = local_shapes
 
+
 class MockDataScienceModel:
     @staticmethod
-    def create(config_file = ""):
+    def create(config_file=""):
         mock_model = MagicMock()
         mock_model.model_file_description = {"test_key": "test_value"}
         mock_model.display_name = re.sub(r"\.json$", "", config_file)
@@ -245,7 +247,7 @@ def create(config_file = ""):
245247 "license" : "test_license" ,
246248 "organization" : "test_organization" ,
247249 "task" : "text-generation" ,
248- "model_format" : "SAFETENSORS" ,
250+ "model_format" : "SAFETENSORS" ,
249251 "ready_to_fine_tune" : "true" ,
250252 "aqua_custom_base_model" : "true" ,
251253 }
@@ -261,36 +263,68 @@ def create(config_file = ""):
 
 
 class TestAquaShapeRecommend:
-
-    def test_which_gpu_valid(self, monkeypatch, **kwargs):
+    @pytest.mark.parametrize(
+        "config, expected_recs, expected_troubleshoot",
+        [
+            (  # decoder-only model
+                {
+                    "num_hidden_layers": 2,
+                    "hidden_size": 64,
+                    "vocab_size": 1000,
+                    "num_attention_heads": 4,
+                    "head_dim": 16,
+                    "max_position_embeddings": 2048,
+                },
+                [],
+                "",
+            ),
+            (  # encoder-decoder model
+                {
+                    "num_hidden_layers": 2,
+                    "hidden_size": 64,
+                    "vocab_size": 1000,
+                    "num_attention_heads": 4,
+                    "head_dim": 16,
+                    "max_position_embeddings": 2048,
+                    "is_encoder_decoder": True,
+                },
+                [],
+                "Please provide a decoder-only text-generation model (ex. Llama, Falcon, etc). Encoder-decoder models (ex. T5, Gemma) and encoder-only (BERT) are not supported at this time.",
+            ),
+        ],
+    )
+    def test_which_shapes_valid(
+        self, monkeypatch, config, expected_recs, expected_troubleshoot
+    ):
         app = AquaShapeRecommend()
         mock_model = MockDataScienceModel.create()
 
         monkeypatch.setattr(
-            "ads.aqua.app.DataScienceModel.from_id",
-            lambda _: mock_model
+            "ads.aqua.app.DataScienceModel.from_id", lambda _: mock_model
         )
 
-        config = {
-            "num_hidden_layers": 2,
-            "hidden_size": 64,
-            "vocab_size": 1000,
-            "num_attention_heads": 4,
-            "head_dim": 16,
-            "max_position_embeddings": 2048,
-        }
-
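+        # The summarizer mock returns exactly this report, so which_shapes should hand it back unchanged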
+        expected_result = ShapeRecommendationReport(
+            recommendations=expected_recs, troubleshoot=expected_troubleshoot
+        )
         app._get_model_config = MagicMock(return_value=config)
         app.valid_compute_shapes = MagicMock(return_value=[])
-        app._summarize_shapes_for_seq_lens = MagicMock(return_value="mocked_report")
+        app._summarize_shapes_for_seq_lens = MagicMock(return_value=expected_result)
 
-        request = RequestRecommend(model_id="ocid1.datasciencemodel.oc1.TEST")
+        request = RequestRecommend(
+            model_id="ocid1.datasciencemodel.oc1.TEST", generate_table=False
+        )
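+        # generate_table=False asks for the raw report object rather than a rendered table (assumed from the flag name)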
         result = app.which_shapes(request)
+        assert result == expected_result
 
-        app.valid_compute_shapes.assert_called_once()
-        llm_config = LLMConfig.from_raw_config(config)
-        app._summarize_shapes_for_seq_lens.assert_called_once_with(llm_config, [], "")
-        assert result == "mocked_report"
+        # If troubleshoot is populated (error case), _summarize_shapes_for_seq_lens should not have been called
+        if expected_troubleshoot:
+            app._summarize_shapes_for_seq_lens.assert_not_called()
+        else:
+            # For non-error case, summarize should have been called
+            llm_config = LLMConfig.from_raw_config(config)
+            app._summarize_shapes_for_seq_lens.assert_called_once_with(
+                llm_config, [], ""
+            )
 
     @pytest.mark.parametrize(
         "config_file, result_file",
@@ -303,7 +337,9 @@ def test_which_gpu_valid(self, monkeypatch, **kwargs):
             ),
         ],
     )
-    def test_which_gpu_valid_from_file(self, monkeypatch, config_file, result_file, **kwargs):
+    def test_which_shapes_valid_from_file(
+        self, monkeypatch, config_file, result_file, **kwargs
+    ):
         raw = load_config(config_file)
         app = AquaShapeRecommend()
         mock_model = MockDataScienceModel.create(config_file)
@@ -317,9 +353,14 @@ def test_which_gpu_valid_from_file(self, monkeypatch, config_file, result_file,
             ComputeShapeSummary(name=name, shape_series="GPU", gpu_specs=spec)
             for name, spec in shapes_index.shapes.items()
         ]
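+        # Feed the locally loaded GPU shapes straight in, bypassing live shape discovery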
-        monkeypatch.setattr(app, "valid_compute_shapes", lambda *args, **kwargs: real_shapes)
+        monkeypatch.setattr(
+            app, "valid_compute_shapes", lambda *args, **kwargs: real_shapes
+        )
 
-        result = app.which_gpu(model_ocid="ocid1.datasciencemodel.oc1.TEST")
+        request = RequestRecommend(
+            model_id="ocid1.datasciencemodel.oc1.TEST", generate_table=False
+        )
+        result = app.which_shapes(request=request)
 
         expected_result = load_config(result_file)
         assert result.model_dump() == expected_result
@@ -349,7 +390,7 @@ def test_shape_report_pareto_front(self):
                 model_size_gb=1, kv_cache_size_gb=1, total_model_gb=2
             ),
             deployment_params=DeploymentParams(
-                quantization = "8bit", max_model_len = 2048, params = ""
+                quantization="8bit", max_model_len=2048, params=""
             ),
             recommendation="ok",
         )
@@ -363,7 +404,7 @@ def test_shape_report_pareto_front(self):
                 model_size_gb=1, kv_cache_size_gb=1, total_model_gb=2
             ),
             deployment_params=DeploymentParams(
-                quantization = "8bit", max_model_len = 2048, params = ""
+                quantization="8bit", max_model_len=2048, params=""
            ),
             recommendation="ok",
         )
@@ -377,7 +418,7 @@ def test_shape_report_pareto_front(self):
                 model_size_gb=1, kv_cache_size_gb=1, total_model_gb=2
             ),
             deployment_params=DeploymentParams(
-                quantization = "bfloat16", max_model_len = 2048, params = ""
+                quantization="bfloat16", max_model_len=2048, params=""
             ),
             recommendation="ok",
         )