4848 OCIDataScienceModelDeployment ,
4949)
5050
51+
5152class HuggingFaceModelFetcher :
5253 """
5354 Utility class to fetch model configurations from HuggingFace.
@@ -57,7 +58,7 @@ class HuggingFaceModelFetcher:
5758 def is_huggingface_model_id (cls , model_id : str ) -> bool :
5859 if is_valid_ocid (model_id ):
5960 return False
60- hf_pattern = r' ^[a-zA-Z0-9_-]+(/[a-zA-Z0-9_.-]+)?$'
61+ hf_pattern = r" ^[a-zA-Z0-9_-]+(/[a-zA-Z0-9_.-]+)?$"
6162 return bool (re .match (hf_pattern , model_id ))
6263
6364 @classmethod
@@ -80,12 +81,19 @@ def fetch_config_only(cls, model_id: str) -> Dict[str, Any]:
8081 elif response .status_code == 404 :
8182 raise AquaValueError (f"Model '{ model_id } ' not found on HuggingFace." )
8283 elif response .status_code != 200 :
83- raise AquaValueError (f"Failed to fetch config for '{ model_id } '. Status: { response .status_code } " )
84+ raise AquaValueError (
85+ f"Failed to fetch config for '{ model_id } '. Status: { response .status_code } "
86+ )
8487 return response .json ()
8588 except requests .RequestException as e :
86- raise AquaValueError (f"Network error fetching config for { model_id } : { e } " ) from e
89+ raise AquaValueError (
90+ f"Network error fetching config for { model_id } : { e } "
91+ ) from e
8792 except json .JSONDecodeError as e :
88- raise AquaValueError (f"Invalid config format for model '{ model_id } '." ) from e
93+ raise AquaValueError (
94+ f"Invalid config format for model '{ model_id } '."
95+ ) from e
96+
8997
9098class AquaShapeRecommend :
9199 """
@@ -135,7 +143,9 @@ def which_shapes(
135143 """
136144 try :
137145 shapes = self .valid_compute_shapes (compartment_id = request .compartment_id )
138- data , model_name = self ._get_model_config_and_name (request .model_id , request .compartment_id )
146+ data , model_name = self ._get_model_config_and_name (
147+ request .model_id , request .compartment_id
148+ )
139149 llm_config = LLMConfig .from_raw_config (data )
140150 shape_recommendation_report = self ._summarize_shapes_for_seq_lens (
141151 llm_config , shapes , model_name
@@ -165,40 +175,55 @@ def which_shapes(
165175
166176 return shape_recommendation_report
167177
168- def _get_model_config_and_name (self , model_id : str , compartment_id : str ) -> (dict , str ):
178+ def _get_model_config_and_name (
179+ self , model_id : str , compartment_id : str
180+ ) -> (dict , str ):
169181 """
170182 Loads model configuration, handling OCID and Hugging Face model IDs.
171183 """
172184 if HuggingFaceModelFetcher .is_huggingface_model_id (model_id ):
173185 logger .info (f"'{ model_id } ' identified as a Hugging Face model ID." )
174186 ds_model = self ._search_model_in_catalog (model_id , compartment_id )
175187 if ds_model and ds_model .artifact :
176- logger .info ("Loading configuration from existing model catalog artifact." )
188+ logger .info (
189+ "Loading configuration from existing model catalog artifact."
190+ )
177191 try :
178- return load_config (ds_model .artifact , "config.json" ), ds_model .display_name
192+ return (
193+ load_config (ds_model .artifact , "config.json" ),
194+ ds_model .display_name ,
195+ )
179196 except AquaFileNotFoundError :
180- logger .warning ("config.json not found in artifact, fetching from Hugging Face Hub." )
197+ logger .warning (
198+ "config.json not found in artifact, fetching from Hugging Face Hub."
199+ )
181200 return HuggingFaceModelFetcher .fetch_config_only (model_id ), model_id
182201 else :
183202 logger .info (f"'{ model_id } ' identified as a model OCID." )
184203 ds_model = self ._validate_model_ocid (model_id )
185204 return self ._get_model_config (ds_model ), ds_model .display_name
186205
187- def _search_model_in_catalog (self , model_id : str , compartment_id : str ) -> Optional [DataScienceModel ]:
206+ def _search_model_in_catalog (
207+ self , model_id : str , compartment_id : str
208+ ) -> Optional [DataScienceModel ]:
188209 """
189210 Searches for a Hugging Face model in the Data Science model catalog by display name.
190211 """
191212 try :
192213 # This should work since the SDK's list method can filter by display_name.
193- models = DataScienceModel .list (compartment_id = compartment_id , display_name = model_id )
214+ models = DataScienceModel .list (
215+ compartment_id = compartment_id , display_name = model_id
216+ )
194217 if models :
195218 logger .info (f"Found model '{ model_id } ' in the Data Science catalog." )
196219 return models [0 ]
197220 except Exception as e :
198221 logger .warning (f"Could not search for model '{ model_id } ' in catalog: { e } " )
199222 return None
200223
201- def valid_compute_shapes (self , compartment_id : Optional [str ] = None ) -> List ["ComputeShapeSummary" ]:
224+ def valid_compute_shapes (
225+ self , compartment_id : Optional [str ] = None
226+ ) -> List ["ComputeShapeSummary" ]:
202227 """
203228 Returns a filtered list of GPU-only ComputeShapeSummary objects by reading and parsing a JSON file.
204229
@@ -219,7 +244,9 @@ def valid_compute_shapes(self, compartment_id: Optional[str] = None) -> List["Co
219244 environment variables.
220245 """
221246 if not compartment_id :
222- compartment_id = os .environ .get ("NB_SESSION_COMPARTMENT_OCID" ) or os .environ .get ("PROJECT_COMPARTMENT_OCID" )
247+ compartment_id = os .environ .get (
248+ "NB_SESSION_COMPARTMENT_OCID"
249+ ) or os .environ .get ("PROJECT_COMPARTMENT_OCID" )
223250 if compartment_id :
224251 logger .info (f"Using compartment_id from environment: { compartment_id } " )
225252
0 commit comments