22# Copyright (c) 2025 Oracle and/or its affiliates.
33# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
44
5- #!/usr/bin/env python
6- # Copyright (c) 2025 Oracle and/or its affiliates.
7- # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
8-
9- import shutil
10- import os
115import json
12- from typing import List , Union , Optional , Dict , Any , Tuple
6+ import os
7+ import shutil
8+ from typing import Dict , List , Optional , Tuple , Union
139
1410from pydantic import ValidationError
1511from rich .table import Table
1612
1713from huggingface_hub import hf_hub_download
1814from huggingface_hub .utils import HfHubHTTPError
15+
1916from ads .aqua .app import logger
2017from ads .aqua .common .entities import ComputeShapeSummary
2118from ads .aqua .common .errors import (
2522)
2623from ads .aqua .common .utils import (
2724 build_pydantic_error_message ,
25+ get_resource_type ,
26+ is_valid_ocid ,
2827 load_config ,
2928 load_gpu_shapes_index ,
30- is_valid_ocid ,
31- get_resource_type ,
29+ format_hf_custom_error_message ,
3230)
3331from ads .aqua .shaperecommend .constants import (
34- BITS_AND_BYTES_4BIT ,
3532 BITSANDBYTES ,
33+ BITS_AND_BYTES_4BIT ,
3634 SAFETENSORS ,
3735 SHAPE_MAP ,
3836 TEXT_GENERATION ,
4644 ShapeRecommendationReport ,
4745 ShapeReport ,
4846)
47+ from ads .config import COMPARTMENT_OCID
4948from ads .model .datascience_model import DataScienceModel
5049from ads .model .service .oci_datascience_model_deployment import (
5150 OCIDataScienceModelDeployment ,
@@ -100,9 +99,6 @@ def which_shapes(
10099 """
101100 try :
102101 shapes = self .valid_compute_shapes (compartment_id = request .compartment_id )
103- # data, model_name = self._get_model_config_and_name(
104- # model_id=request.model_id, compartment_id=request.compartment_id
105- # )
106102 data , model_name = self ._get_model_config_and_name (
107103 model_id = request .model_id ,
108104 )
@@ -158,41 +154,18 @@ def _get_model_config_and_name(
158154 - The display name for the model.
159155 """
160156 if is_valid_ocid (model_id ):
161- logger .info (f"' { model_id } ' identified as a model OCID ." )
157+ logger .info (f"Detected OCID: Fetching OCI model config for ' { model_id } ' ." )
162158 ds_model = self ._validate_model_ocid (model_id )
163- return self ._get_model_config (ds_model ), ds_model .display_name
159+ config = self ._fetch_hf_config (model_id )
160+ model_name = ds_model .display_name
161+ else :
162+ logger .info (
163+ f"Assuming Hugging Face model ID: Fetching config for '{ model_id } '."
164+ )
165+ config = self ._fetch_hf_config (model_id )
166+ model_name = model_id
164167
165- logger .info (
166- f"'{ model_id } ' is not an OCID, treating as a Hugging Face model ID."
167- )
168- # if not compartment_id:
169- # compartment_id = os.environ.get(
170- # "NB_SESSION_COMPARTMENT_OCID"
171- # ) or os.environ.get("PROJECT_COMPARTMENT_OCID")
172- # if compartment_id:
173- # logger.info(f"Using compartment_id from environment: {compartment_id}")
174- # if not compartment_id:
175- # raise AquaValueError(
176- # "A compartment OCID is required to list available shapes. "
177- # "Please provide it as a parameter or set the 'NB_SESSION_COMPARTMENT_OCID' "
178- # "or 'PROJECT_COMPARTMENT_OCID' environment variable."
179- # "cli command: export NB_SESSION_COMPARTMENT_OCID=<NB_SESSION_COMPARTMENT_OCID>"
180- # )
181-
182- # ds_model = self._search_model_in_catalog(model_id, compartment_id)
183- # if ds_model:
184- # logger.info("Loading configuration from existing model catalog artifact.")
185- # try:
186- # return (
187- # self._get_model_config(ds_model),
188- # ds_model.display_name,
189- # )
190- # except AquaFileNotFoundError:
191- # logger.warning(
192- # "config.json not found in artifact, fetching from Hugging Face Hub."
193- # )
194-
195- return self ._fetch_hf_config (model_id ), model_id
168+ return config , model_name
196169
197170 def _fetch_hf_config (self , model_id : str ) -> Dict :
198171 """
@@ -204,38 +177,7 @@ def _fetch_hf_config(self, model_id: str) -> Dict:
204177 with open (config_path , "r" , encoding = "utf-8" ) as f :
205178 return json .load (f )
206179 except HfHubHTTPError as e :
207- if "401" in str (e ):
208- raise AquaValueError (
209- f"Model '{ model_id } ' requires authentication. Please set your HuggingFace access token as an environment variable (HF_TOKEN). cli command: export HF_TOKEN=<HF_TOKEN>"
210- )
211- elif "404" in str (e ) or "not found" in str (e ).lower ():
212- raise AquaValueError (
213- f"Model '{ model_id } ' not found on HuggingFace. Please check the name for typos."
214- )
215- raise AquaValueError (
216- f"Failed to download config for '{ model_id } ': { e } "
217- ) from e
218-
219- # def _search_model_in_catalog(
220- # self, model_id: str, compartment_id: str
221- # ) -> Optional[DataScienceModel]:
222- # """
223- # Searches for a model in the Data Science catalog by its display name.
224- # """
225- # try:
226- # models = DataScienceModel.list(
227- # compartment_id=compartment_id, display_name=model_id
228- # )
229- # if len(models) > 1:
230- # logger.warning(
231- # f"Found multiple models with the name '{model_id}'. Using the first one found."
232- # )
233- # if models:
234- # logger.info(f"Found model '{model_id}' in the Data Science catalog.")
235- # return models[0]
236- # except Exception as e:
237- # logger.warning(f"Could not search for model '{model_id}' in catalog: {e}")
238- # return None
180+ format_hf_custom_error_message (e )
239181
240182 def valid_compute_shapes (
241183 self , compartment_id : Optional [str ] = None
@@ -260,18 +202,16 @@ def valid_compute_shapes(
260202 environment variables.
261203 """
262204 if not compartment_id :
263- compartment_id = os .environ .get (
264- "NB_SESSION_COMPARTMENT_OCID"
265- ) or os .environ .get ("PROJECT_COMPARTMENT_OCID" )
205+ compartment_id = COMPARTMENT_OCID
266206 if compartment_id :
267207 logger .info (f"Using compartment_id from environment: { compartment_id } " )
268208
269209 if not compartment_id :
270210 raise AquaValueError (
271211 "A compartment OCID is required to list available shapes. "
272- "Please provide it as a parameter or set the 'NB_SESSION_COMPARTMENT_OCID' "
273- "or 'PROJECT_COMPARTMENT_OCID' environment variable. "
274- "cli command: export NB_SESSION_COMPARTMENT_OCID=<NB_SESSION_COMPARTMENT_OCID>"
212+ "Please specify it using the --compartment_id parameter. \n \n "
213+ "Example: \n "
214+ 'ads aqua deployment recommend_shape --model_id "<YOUR_MODEL_OCID>" --compartment_id "<YOUR_COMPARTMENT_OCID>"'
275215 )
276216
277217 oci_shapes = OCIDataScienceModelDeployment .shapes (compartment_id = compartment_id )
0 commit comments