|
3 | 3 | # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ |
4 | 4 |
|
5 | 5 | import shutil |
6 | | -from typing import List, Union |
| 6 | +import os |
| 7 | +import re |
| 8 | +import json |
| 9 | +import requests |
| 10 | +from typing import List, Union, Optional, Dict, Any |
7 | 11 |
|
8 | 12 | from pydantic import ValidationError |
9 | 13 | from rich.table import Table |
|
42 | 46 | OCIDataScienceModelDeployment, |
43 | 47 | ) |
44 | 48 |
|
| 49 | +class HuggingFaceModelFetcher: |
| 50 | + """ |
| 51 | + Utility class to fetch model configurations from HuggingFace. |
| 52 | + """ |
| 53 | + HUGGINGFACE_CONFIG_URL = "https://huggingface.co/{model_id}/resolve/main/config.json" |
| 54 | + |
| 55 | + @classmethod |
| 56 | + def is_huggingface_model_id(cls, model_id: str) -> bool: |
| 57 | + if model_id.startswith("ocid1."): |
| 58 | + return False |
| 59 | + hf_pattern = r'^[a-zA-Z0-9_-]+(/[a-zA-Z0-9_.-]+)?$' |
| 60 | + return bool(re.match(hf_pattern, model_id)) |
| 61 | + |
| 62 | + @classmethod |
| 63 | + def get_hf_token(cls) -> Optional[str]: |
| 64 | + return os.environ.get("HUGGING_FACE_HUB_TOKEN") or os.environ.get("HF_TOKEN") |
| 65 | + |
| 66 | + @classmethod |
| 67 | + def fetch_config_only(cls, model_id: str) -> Dict[str, Any]: |
| 68 | + try: |
| 69 | + config_url = cls.HUGGINGFACE_CONFIG_URL.format(model_id=model_id) |
| 70 | + headers = {} |
| 71 | + token = cls.get_hf_token() |
| 72 | + if token: |
| 73 | + headers["Authorization"] = f"Bearer {token}" |
| 74 | + response = requests.get(config_url, headers=headers, timeout=10) |
| 75 | + if response.status_code == 401: |
| 76 | + raise AquaValueError( |
| 77 | + f"Model '{model_id}' requires authentication. Please set your HuggingFace token." |
| 78 | + ) |
| 79 | + elif response.status_code == 404: |
| 80 | + raise AquaValueError(f"Model '{model_id}' not found on HuggingFace.") |
| 81 | + elif response.status_code != 200: |
| 82 | + raise AquaValueError(f"Failed to fetch config for '{model_id}'. Status: {response.status_code}") |
| 83 | + return response.json() |
| 84 | + except requests.RequestException as e: |
| 85 | + raise AquaValueError(f"Network error fetching config for {model_id}: {e}") from e |
| 86 | + except json.JSONDecodeError as e: |
| 87 | + raise AquaValueError(f"Invalid config format for model '{model_id}'.") from e |
45 | 88 |
|
46 | 89 | class AquaShapeRecommend: |
47 | 90 | """ |
@@ -91,14 +134,8 @@ def which_shapes( |
91 | 134 | """ |
92 | 135 | try: |
93 | 136 | shapes = self.valid_compute_shapes(compartment_id=request.compartment_id) |
94 | | - |
95 | | - ds_model = self._validate_model_ocid(request.model_id) |
96 | | - data = self._get_model_config(ds_model) |
97 | | - |
| 137 | + data, model_name = self._get_model_config_and_name(request.model_id, request.compartment_id) |
98 | 138 | llm_config = LLMConfig.from_raw_config(data) |
99 | | - |
100 | | - model_name = ds_model.display_name if ds_model.display_name else "" |
101 | | - |
102 | 139 | shape_recommendation_report = self._summarize_shapes_for_seq_lens( |
103 | 140 | llm_config, shapes, model_name |
104 | 141 | ) |
@@ -127,6 +164,39 @@ def which_shapes( |
127 | 164 |
|
128 | 165 | return shape_recommendation_report |
129 | 166 |
|
| 167 | + def _get_model_config_and_name(self, model_id: str, compartment_id: str) -> (dict, str): |
| 168 | + """ |
| 169 | + Loads model configuration, handling OCID and Hugging Face model IDs. |
| 170 | + """ |
| 171 | + if HuggingFaceModelFetcher.is_huggingface_model_id(model_id): |
| 172 | + logger.info(f"'{model_id}' identified as a Hugging Face model ID.") |
| 173 | + ds_model = self._search_model_in_catalog(model_id, compartment_id) |
| 174 | + if ds_model and ds_model.artifact: |
| 175 | + logger.info("Loading configuration from existing model catalog artifact.") |
| 176 | + try: |
| 177 | + return load_config(ds_model.artifact, "config.json"), ds_model.display_name |
| 178 | + except AquaFileNotFoundError: |
| 179 | + logger.warning("config.json not found in artifact, fetching from Hugging Face Hub.") |
| 180 | + return HuggingFaceModelFetcher.fetch_config_only(model_id), model_id |
| 181 | + else: |
| 182 | + logger.info(f"'{model_id}' identified as a model OCID.") |
| 183 | + ds_model = self._validate_model_ocid(model_id) |
| 184 | + return self._get_model_config(ds_model), ds_model.display_name |
| 185 | + |
| 186 | + def _search_model_in_catalog(self, model_id: str, compartment_id: str) -> Optional[DataScienceModel]: |
| 187 | + """ |
| 188 | + Searches for a Hugging Face model in the Data Science model catalog by display name. |
| 189 | + """ |
| 190 | + try: |
| 191 | + # This should work since the SDK's list method can filter by display_name. |
| 192 | + models = DataScienceModel.list(compartment_id=compartment_id, display_name=model_id) |
| 193 | + if models: |
| 194 | + logger.info(f"Found model '{model_id}' in the Data Science catalog.") |
| 195 | + return models[0] |
| 196 | + except Exception as e: |
| 197 | + logger.warning(f"Could not search for model '{model_id}' in catalog: {e}") |
| 198 | + return None |
| 199 | + |
130 | 200 | def valid_compute_shapes(self, compartment_id: str) -> List["ComputeShapeSummary"]: |
131 | 201 | """ |
132 | 202 | Returns a filtered list of GPU-only ComputeShapeSummary objects by reading and parsing a JSON file. |
|
0 commit comments