Skip to content

Commit 19ddce1

Browse files
committed
using huggingface_hub.hf_hub_download and design changes
1 parent efe0953 commit 19ddce1

File tree

1 file changed

+45
-63
lines changed

1 file changed

+45
-63
lines changed

ads/aqua/shaperecommend/recommend.py

Lines changed: 45 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,18 @@
22
# Copyright (c) 2025 Oracle and/or its affiliates.
33
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
44

5+
#!/usr/bin/env python
6+
# Copyright (c) 2025 Oracle and/or its affiliates.
7+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
8+
59
import shutil
610
import os
7-
import re
811
import json
9-
import requests
10-
from typing import List, Union, Optional, Dict, Any
12+
from typing import List, Union, Optional, Dict, Any, Tuple
1113

1214
from pydantic import ValidationError
1315
from rich.table import Table
16+
from huggingface_hub import hf_hub_download, HfHubHTTPError
1417

1518
from ads.aqua.app import logger
1619
from ads.aqua.common.entities import ComputeShapeSummary
@@ -21,7 +24,6 @@
2124
)
2225
from ads.aqua.common.utils import (
2326
build_pydantic_error_message,
24-
get_resource_type,
2527
load_config,
2628
load_gpu_shapes_index,
2729
is_valid_ocid,
@@ -33,7 +35,6 @@
3335
SHAPE_MAP,
3436
TEXT_GENERATION,
3537
TROUBLESHOOT_MSG,
36-
HUGGINGFACE_CONFIG_URL,
3738
)
3839
from ads.aqua.shaperecommend.estimator import get_estimator
3940
from ads.aqua.shaperecommend.llm_config import LLMConfig
@@ -49,54 +50,6 @@
4950
)
5051

5152

52-
class HuggingFaceModelFetcher:
53-
"""
54-
Utility class to fetch model configurations from HuggingFace.
55-
"""
56-
57-
@classmethod
58-
def is_huggingface_model_id(cls, model_id: str) -> bool:
59-
if is_valid_ocid(model_id):
60-
return False
61-
hf_pattern = r"^[a-zA-Z0-9_-]+(/[a-zA-Z0-9_.-]+)?$"
62-
return bool(re.match(hf_pattern, model_id))
63-
64-
@classmethod
65-
def get_hf_token(cls) -> Optional[str]:
66-
return os.environ.get("HUGGING_FACE_HUB_TOKEN") or os.environ.get("HF_TOKEN")
67-
68-
@classmethod
69-
def fetch_config_only(cls, model_id: str) -> Dict[str, Any]:
70-
try:
71-
config_url = HUGGINGFACE_CONFIG_URL.format(model_id=model_id)
72-
headers = {}
73-
token = cls.get_hf_token()
74-
if token:
75-
headers["Authorization"] = f"Bearer {token}"
76-
response = requests.get(config_url, headers=headers, timeout=10)
77-
if response.status_code == 401:
78-
raise AquaValueError(
79-
f"Model '{model_id}' requires authentication. Please set your HuggingFace access token as an environment variable."
80-
)
81-
elif response.status_code == 404:
82-
raise AquaValueError(
83-
f"Model '{model_id}' not found on HuggingFace. Please check the name for typos."
84-
)
85-
elif response.status_code != 200:
86-
raise AquaValueError(
87-
f"Failed to fetch config for '{model_id}'. Status: {response.status_code}"
88-
)
89-
return response.json()
90-
except requests.RequestException as e:
91-
raise AquaValueError(
92-
f"Network error fetching config for {model_id}: {e}"
93-
) from e
94-
except json.JSONDecodeError as e:
95-
raise AquaValueError(
96-
f"Invalid config format for model '{model_id}'."
97-
) from e
98-
99-
10053
class AquaShapeRecommend:
10154
"""
10255
Interface for recommending GPU shapes for machine learning model deployments
@@ -146,7 +99,7 @@ def which_shapes(
14699
try:
147100
shapes = self.valid_compute_shapes(compartment_id=request.compartment_id)
148101
data, model_name = self._get_model_config_and_name(
149-
request.model_id, request.compartment_id
102+
model_id=request.model_id, compartment_id=request.compartment_id
150103
)
151104
llm_config = LLMConfig.from_raw_config(data)
152105
shape_recommendation_report = self._summarize_shapes_for_seq_lens(
@@ -183,8 +136,15 @@ def _get_model_config_and_name(
183136
"""
184137
Loads model configuration, handling OCID and Hugging Face model IDs.
185138
"""
186-
if HuggingFaceModelFetcher.is_huggingface_model_id(model_id):
187-
logger.info(f"'{model_id}' identified as a Hugging Face model ID.")
139+
if is_valid_ocid(model_id):
140+
logger.info(f"'{model_id}' identified as a model OCID.")
141+
ds_model = self._validate_model_ocid(model_id)
142+
return self._get_model_config(ds_model), ds_model.display_name
143+
144+
logger.info(
145+
f"'{model_id}' is not an OCID, treating as a Hugging Face model ID."
146+
)
147+
if compartment_id:
188148
ds_model = self._search_model_in_catalog(model_id, compartment_id)
189149
if ds_model:
190150
logger.info(
@@ -199,23 +159,45 @@ def _get_model_config_and_name(
199159
logger.warning(
200160
"config.json not found in artifact, fetching from Hugging Face Hub."
201161
)
202-
return HuggingFaceModelFetcher.fetch_config_only(model_id), model_id
203-
else:
204-
logger.info(f"'{model_id}' identified as a model OCID.")
205-
ds_model = self._validate_model_ocid(model_id)
206-
return self._get_model_config(ds_model), ds_model.display_name
162+
163+
return self._fetch_hf_config(model_id), model_id
164+
165+
def _fetch_hf_config(self, model_id: str) -> Dict:
166+
"""
167+
Downloads a model's config.json from Hugging Face Hub using the
168+
huggingface_hub library.
169+
"""
170+
try:
171+
config_path = hf_hub_download(repo_id=model_id, filename="config.json")
172+
with open(config_path, "r", encoding="utf-8") as f:
173+
return json.load(f)
174+
except HfHubHTTPError as e:
175+
if "401" in str(e):
176+
raise AquaValueError(
177+
f"Model '{model_id}' requires authentication. Please set your HuggingFace access token as an environment variable (HF_TOKEN). cli command: export HF_TOKEN=<HF_TOKEN>"
178+
)
179+
elif "404" in str(e) or "not found" in str(e).lower():
180+
raise AquaValueError(
181+
f"Model '{model_id}' not found on HuggingFace. Please check the name for typos."
182+
)
183+
raise AquaValueError(
184+
f"Failed to download config for '{model_id}': {e}"
185+
) from e
207186

208187
def _search_model_in_catalog(
209188
self, model_id: str, compartment_id: str
210189
) -> Optional[DataScienceModel]:
211190
"""
212-
Searches for a model in the Data Science model catalog by display name.
191+
Searches for a model in the Data Science catalog by its display name.
213192
"""
214193
try:
215-
# This should work since the SDK's list method can filter by display_name.
216194
models = DataScienceModel.list(
217195
compartment_id=compartment_id, display_name=model_id
218196
)
197+
if len(models) > 1:
198+
logger.warning(
199+
f"Found multiple models with the name '{model_id}'. Using the first one found."
200+
)
219201
if models:
220202
logger.info(f"Found model '{model_id}' in the Data Science catalog.")
221203
return models[0]

0 commit comments

Comments
 (0)