44
55import json
66import os
7+ import re
78import shutil
89from typing import Dict , List , Optional , Tuple , Union
910
10- from pydantic import ValidationError
11- from rich .table import Table
12-
1311from huggingface_hub import hf_hub_download
1412from huggingface_hub .utils import HfHubHTTPError
13+ from pydantic import ValidationError
14+ from rich .table import Table
1515
1616from ads .aqua .app import logger
1717from ads .aqua .common .entities import ComputeShapeSummary
2222)
2323from ads .aqua .common .utils import (
2424 build_pydantic_error_message ,
25+ format_hf_custom_error_message ,
2526 get_resource_type ,
2627 is_valid_ocid ,
2728 load_config ,
2829 load_gpu_shapes_index ,
29- format_hf_custom_error_message ,
3030)
3131from ads .aqua .shaperecommend .constants import (
32- BITSANDBYTES ,
3332 BITS_AND_BYTES_4BIT ,
33+ BITSANDBYTES ,
3434 SAFETENSORS ,
3535 SHAPE_MAP ,
3636 TEXT_GENERATION ,
@@ -98,14 +98,10 @@ def which_shapes(
9898 """
9999 try :
100100 shapes = self .valid_compute_shapes (compartment_id = request .compartment_id )
101-
101+
102102 data , model_name = self ._get_model_config_and_name (
103103 model_id = request .model_id ,
104104 )
105- llm_config = LLMConfig .from_raw_config (data )
106- shape_recommendation_report = self ._summarize_shapes_for_seq_lens (
107- llm_config , shapes , model_name
108- )
109105
110106 if request .deployment_config :
111107 shape_recommendation_report = (
@@ -115,16 +111,11 @@ def which_shapes(
115111 )
116112
117113 else :
118- ds_model = self ._get_data_science_model (request .model_id )
119-
120- data = self ._get_model_config (ds_model )
121-
122114 llm_config = LLMConfig .from_raw_config (data )
123115
124116 shape_recommendation_report = self ._summarize_shapes_for_seq_lens (
125117 llm_config , shapes , model_name
126118 )
127-
128119
129120 if request .generate_table and shape_recommendation_report .recommendations :
130121 shape_recommendation_report = self ._rich_diff_table (
@@ -174,8 +165,8 @@ def _get_model_config_and_name(
174165 """
175166 if is_valid_ocid (model_id ):
176167 logger .info (f"Detected OCID: Fetching OCI model config for '{ model_id } '." )
177- ds_model = self ._validate_model_ocid (model_id )
178- config = self ._fetch_hf_config ( model_id )
168+ ds_model = self ._get_data_science_model (model_id )
169+ config = self ._get_model_config ( ds_model )
179170 model_name = ds_model .display_name
180171 else :
181172 logger .info (
@@ -403,6 +394,7 @@ def _get_model_config(model: DataScienceModel):
403394 """
404395
405396 model_task = model .freeform_tags .get ("task" , "" ).lower ()
397+ model_task = re .sub (r"-" , "_" , model_task )
406398 model_format = model .freeform_tags .get ("model_format" , "" ).lower ()
407399
408400 logger .info (f"Current model task type: { model_task } " )
0 commit comments