22# Copyright (c) 2025 Oracle and/or its affiliates.
33# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
44
5+ #!/usr/bin/env python
6+ # Copyright (c) 2025 Oracle and/or its affiliates.
7+ # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
8+
59import shutil
610import os
7- import re
811import json
9- import requests
10- from typing import List , Union , Optional , Dict , Any
12+ from typing import List , Union , Optional , Dict , Any , Tuple
1113
1214from pydantic import ValidationError
1315from rich .table import Table
16+ from huggingface_hub import hf_hub_download , HfHubHTTPError
1417
1518from ads .aqua .app import logger
1619from ads .aqua .common .entities import ComputeShapeSummary
2124)
2225from ads .aqua .common .utils import (
2326 build_pydantic_error_message ,
24- get_resource_type ,
2527 load_config ,
2628 load_gpu_shapes_index ,
2729 is_valid_ocid ,
3335 SHAPE_MAP ,
3436 TEXT_GENERATION ,
3537 TROUBLESHOOT_MSG ,
36- HUGGINGFACE_CONFIG_URL ,
3738)
3839from ads .aqua .shaperecommend .estimator import get_estimator
3940from ads .aqua .shaperecommend .llm_config import LLMConfig
4950)
5051
5152
class HuggingFaceModelFetcher:
    """
    Utility class to fetch model configurations from HuggingFace.

    All members are classmethods, so the class is used without instantiation.
    """

    # Accepts "name" or "org/name". Kept identical to the pattern used so far.
    _HF_MODEL_ID_PATTERN = r"^[a-zA-Z0-9_-]+(/[a-zA-Z0-9_.-]+)?$"

    @classmethod
    def is_huggingface_model_id(cls, model_id: str) -> bool:
        """
        Tells whether ``model_id`` should be treated as a Hugging Face model ID.

        Parameters
        ----------
        model_id : str
            The identifier to classify.

        Returns
        -------
        bool
            False when ``is_valid_ocid`` accepts the value, otherwise whether
            the value matches the ``name`` / ``org/name`` pattern.
        """
        if is_valid_ocid(model_id):
            return False
        return bool(re.match(cls._HF_MODEL_ID_PATTERN, model_id))

    @classmethod
    def get_hf_token(cls) -> Optional[str]:
        """
        Reads the Hugging Face access token from the environment.

        Returns
        -------
        Optional[str]
            The value of ``HUGGING_FACE_HUB_TOKEN`` if set and non-empty,
            otherwise the value of ``HF_TOKEN`` (``None`` when neither is set).
        """
        return os.environ.get("HUGGING_FACE_HUB_TOKEN") or os.environ.get("HF_TOKEN")

    @classmethod
    def fetch_config_only(cls, model_id: str) -> Dict[str, Any]:
        """
        Fetches the configuration JSON of a model over HTTP.

        Parameters
        ----------
        model_id : str
            The model ID substituted into ``HUGGINGFACE_CONFIG_URL``.

        Returns
        -------
        Dict[str, Any]
            The decoded JSON body of the response.

        Raises
        ------
        AquaValueError
            If the request fails, the response status is not 200, or the
            response body is not valid JSON.
        """
        config_url = HUGGINGFACE_CONFIG_URL.format(model_id=model_id)
        token = cls.get_hf_token()
        headers = {"Authorization": f"Bearer {token}"} if token else {}

        # Only the network call is guarded here so that transport failures are
        # the only thing reported as a network error.
        try:
            response = requests.get(config_url, headers=headers, timeout=10)
        except requests.RequestException as e:
            raise AquaValueError(
                f"Network error fetching config for {model_id}: {e}"
            ) from e

        if response.status_code == 401:
            raise AquaValueError(
                f"Model '{model_id}' requires authentication. Please set your HuggingFace access token as an environment variable."
            )
        if response.status_code == 404:
            raise AquaValueError(
                f"Model '{model_id}' not found on HuggingFace. Please check the name for typos."
            )
        if response.status_code != 200:
            raise AquaValueError(
                f"Failed to fetch config for '{model_id}'. Status: {response.status_code}"
            )

        # The decode error raised by response.json() is a ValueError in every
        # requests version; in recent versions it is also a RequestException,
        # which is why it must not share a try block with the network call.
        try:
            return response.json()
        except ValueError as e:
            raise AquaValueError(
                f"Invalid config format for model '{model_id}'."
            ) from e
98-
99-
10053class AquaShapeRecommend :
10154 """
10255 Interface for recommending GPU shapes for machine learning model deployments
@@ -146,7 +99,7 @@ def which_shapes(
14699 try :
147100 shapes = self .valid_compute_shapes (compartment_id = request .compartment_id )
148101 data , model_name = self ._get_model_config_and_name (
149- request .model_id , request .compartment_id
102+ model_id = request .model_id , compartment_id = request .compartment_id
150103 )
151104 llm_config = LLMConfig .from_raw_config (data )
152105 shape_recommendation_report = self ._summarize_shapes_for_seq_lens (
@@ -183,8 +136,15 @@ def _get_model_config_and_name(
183136 """
184137 Loads model configuration, handling OCID and Hugging Face model IDs.
185138 """
186- if HuggingFaceModelFetcher .is_huggingface_model_id (model_id ):
187- logger .info (f"'{ model_id } ' identified as a Hugging Face model ID." )
139+ if is_valid_ocid (model_id ):
140+ logger .info (f"'{ model_id } ' identified as a model OCID." )
141+ ds_model = self ._validate_model_ocid (model_id )
142+ return self ._get_model_config (ds_model ), ds_model .display_name
143+
144+ logger .info (
145+ f"'{ model_id } ' is not an OCID, treating as a Hugging Face model ID."
146+ )
147+ if compartment_id :
188148 ds_model = self ._search_model_in_catalog (model_id , compartment_id )
189149 if ds_model :
190150 logger .info (
@@ -199,23 +159,45 @@ def _get_model_config_and_name(
199159 logger .warning (
200160 "config.json not found in artifact, fetching from Hugging Face Hub."
201161 )
202- return HuggingFaceModelFetcher .fetch_config_only (model_id ), model_id
203- else :
204- logger .info (f"'{ model_id } ' identified as a model OCID." )
205- ds_model = self ._validate_model_ocid (model_id )
206- return self ._get_model_config (ds_model ), ds_model .display_name
162+
163+ return self ._fetch_hf_config (model_id ), model_id
164+
def _fetch_hf_config(self, model_id: str) -> Dict[str, Any]:
    """
    Downloads a model's config.json from Hugging Face Hub using the
    huggingface_hub library.

    Parameters
    ----------
    model_id : str
        The Hugging Face repository ID, passed to ``hf_hub_download`` as
        ``repo_id``.

    Returns
    -------
    Dict[str, Any]
        The parsed contents of the downloaded ``config.json``.

    Raises
    ------
    AquaValueError
        If the Hub answers with an HTTP error (authentication required,
        model not found, or any other failure), or if the downloaded file
        is not valid JSON.
    """
    try:
        config_path = hf_hub_download(repo_id=model_id, filename="config.json")
    except HfHubHTTPError as e:
        # Prefer the real HTTP status over substring matching: the error text
        # embeds the URL and a request id, either of which may contain
        # "401" or "404" by coincidence.
        response = getattr(e, "response", None)
        status_code = getattr(response, "status_code", None)
        if status_code is not None:
            unauthorized = status_code == 401
            not_found = status_code == 404
        else:
            # No response attached to the error; fall back to the message.
            error_text = str(e)
            unauthorized = "401" in error_text
            not_found = "404" in error_text or "not found" in error_text.lower()

        if unauthorized:
            raise AquaValueError(
                f"Model '{model_id}' requires authentication. Please set your HuggingFace access token as an environment variable (HF_TOKEN). cli command: export HF_TOKEN=<HF_TOKEN>"
            ) from e
        if not_found:
            raise AquaValueError(
                f"Model '{model_id}' not found on HuggingFace. Please check the name for typos."
            ) from e
        raise AquaValueError(
            f"Failed to download config for '{model_id}': {e}"
        ) from e

    try:
        with open(config_path, "r", encoding="utf-8") as f:
            return json.load(f)
    except json.JSONDecodeError as e:
        # Surface a malformed config.json as the same error type as every
        # other failure of this method instead of a raw decode error.
        raise AquaValueError(
            f"Invalid config format for model '{model_id}'."
        ) from e
207186
208187 def _search_model_in_catalog (
209188 self , model_id : str , compartment_id : str
210189 ) -> Optional [DataScienceModel ]:
211190 """
212- Searches for a model in the Data Science model catalog by display name.
191+ Searches for a model in the Data Science catalog by its display name.
213192 """
214193 try :
215- # This should work since the SDK's list method can filter by display_name.
216194 models = DataScienceModel .list (
217195 compartment_id = compartment_id , display_name = model_id
218196 )
197+ if len (models ) > 1 :
198+ logger .warning (
199+ f"Found multiple models with the name '{ model_id } '. Using the first one found."
200+ )
219201 if models :
220202 logger .info (f"Found model '{ model_id } ' in the Data Science catalog." )
221203 return models [0 ]
0 commit comments