Skip to content

Commit e78ca08

Browse files
committed
added HUGGINGFACE_CONFIG_URL and is_valid_ocid, and fixed error message
1 parent d2bf709 commit e78ca08

File tree

2 files changed

+7
-4
lines changed

2 files changed

+7
-4
lines changed

ads/aqua/shaperecommend/constants.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,3 +114,5 @@
114114
"ARM": "CPU",
115115
"UNKNOWN_ENUM_VALUE": "N/A",
116116
}
117+
118+
HUGGINGFACE_CONFIG_URL = "https://huggingface.co/{model_id}/resolve/main/config.json"

ads/aqua/shaperecommend/recommend.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
get_resource_type,
2525
load_config,
2626
load_gpu_shapes_index,
27+
is_valid_ocid,
2728
)
2829
from ads.aqua.shaperecommend.constants import (
2930
BITS_AND_BYTES_4BIT,
@@ -32,6 +33,7 @@
3233
SHAPE_MAP,
3334
TEXT_GENERATION,
3435
TROUBLESHOOT_MSG,
36+
HUGGINGFACE_CONFIG_URL,
3537
)
3638
from ads.aqua.shaperecommend.estimator import get_estimator
3739
from ads.aqua.shaperecommend.llm_config import LLMConfig
@@ -50,11 +52,10 @@ class HuggingFaceModelFetcher:
5052
"""
5153
Utility class to fetch model configurations from HuggingFace.
5254
"""
53-
HUGGINGFACE_CONFIG_URL = "https://huggingface.co/{model_id}/resolve/main/config.json"
5455

5556
@classmethod
5657
def is_huggingface_model_id(cls, model_id: str) -> bool:
57-
if model_id.startswith("ocid1."):
58+
if is_valid_ocid(model_id):
5859
return False
5960
hf_pattern = r'^[a-zA-Z0-9_-]+(/[a-zA-Z0-9_.-]+)?$'
6061
return bool(re.match(hf_pattern, model_id))
@@ -66,15 +67,15 @@ def get_hf_token(cls) -> Optional[str]:
6667
@classmethod
6768
def fetch_config_only(cls, model_id: str) -> Dict[str, Any]:
6869
try:
69-
config_url = cls.HUGGINGFACE_CONFIG_URL.format(model_id=model_id)
70+
config_url = HUGGINGFACE_CONFIG_URL.format(model_id=model_id)
7071
headers = {}
7172
token = cls.get_hf_token()
7273
if token:
7374
headers["Authorization"] = f"Bearer {token}"
7475
response = requests.get(config_url, headers=headers, timeout=10)
7576
if response.status_code == 401:
7677
raise AquaValueError(
77-
f"Model '{model_id}' requires authentication. Please set your HuggingFace token."
78+
f"Model '{model_id}' requires authentication. Please set your HuggingFace access token as an environment variable."
7879
)
7980
elif response.status_code == 404:
8081
raise AquaValueError(f"Model '{model_id}' not found on HuggingFace.")

0 commit comments

Comments
 (0)