@@ -99,18 +99,23 @@ def which_shapes(
9999 try :
100100 shapes = self .valid_compute_shapes (compartment_id = request .compartment_id )
101101
102- data , model_name = self ._get_model_config_and_name (
103- model_id = request .model_id ,
104- )
105-
106102 if request .deployment_config :
103+ if is_valid_ocid (request .model_id ):
104+ ds_model = self ._get_data_science_model (request .model_id )
105+ model_name = ds_model .display_name
106+ else :
107+ model_name = request .model_id
108+
107109 shape_recommendation_report = (
108110 ShapeRecommendationReport .from_deployment_config (
109111 request .deployment_config , model_name , shapes
110112 )
111113 )
112114
113115 else :
116+ data , model_name = self ._get_model_config_and_name (
117+ model_id = request .model_id ,
118+ )
114119 llm_config = LLMConfig .from_raw_config (data )
115120
116121 shape_recommendation_report = self ._summarize_shapes_for_seq_lens (
@@ -394,7 +399,7 @@ def _get_model_config(model: DataScienceModel):
394399 """
395400
396401 model_task = model .freeform_tags .get ("task" , "" ).lower ()
397- model_task = re .sub (r"-" , "_" , model_task )
402+ model_task = re .sub (r"-" , "_" , model_task )
398403 model_format = model .freeform_tags .get ("model_format" , "" ).lower ()
399404
400405 logger .info (f"Current model task type: { model_task } " )
0 commit comments