Skip to content

Commit e324143

Browse files
committed
added compartment id logic _get_model_config as a sanity check
1 parent 3d935bb commit e324143

File tree

1 file changed

+26
-15
lines changed

1 file changed

+26
-15
lines changed

ads/aqua/shaperecommend/recommend.py

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def which_shapes(
131131
return shape_recommendation_report
132132

133133
def _get_model_config_and_name(
134-
self, model_id: str, compartment_id: str
134+
self, model_id: str, compartment_id: Optional[str] = None
135135
) -> (dict, str):
136136
"""
137137
Loads model configuration, handling OCID and Hugging Face model IDs.
@@ -144,21 +144,32 @@ def _get_model_config_and_name(
144144
logger.info(
145145
f"'{model_id}' is not an OCID, treating as a Hugging Face model ID."
146146
)
147-
if compartment_id:
148-
ds_model = self._search_model_in_catalog(model_id, compartment_id)
149-
if ds_model:
150-
logger.info(
151-
"Loading configuration from existing model catalog artifact."
147+
if not compartment_id:
148+
compartment_id = os.environ.get(
149+
"NB_SESSION_COMPARTMENT_OCID"
150+
) or os.environ.get("PROJECT_COMPARTMENT_OCID")
151+
if compartment_id:
152+
logger.info(f"Using compartment_id from environment: {compartment_id}")
153+
if not compartment_id:
154+
raise AquaValueError(
155+
"A compartment OCID is required to list available shapes. "
156+
"Please provide it as a parameter or set the 'NB_SESSION_COMPARTMENT_OCID' "
157+
"or 'PROJECT_COMPARTMENT_OCID' environment variable."
158+
"cli command: export NB_SESSION_COMPARTMENT_OCID=<NB_SESSION_COMPARTMENT_OCID>"
159+
)
160+
161+
ds_model = self._search_model_in_catalog(model_id, compartment_id)
162+
if ds_model:
163+
logger.info("Loading configuration from existing model catalog artifact.")
164+
try:
165+
return (
166+
self._get_model_config(ds_model),
167+
ds_model.display_name,
168+
)
169+
except AquaFileNotFoundError:
170+
logger.warning(
171+
"config.json not found in artifact, fetching from Hugging Face Hub."
152172
)
153-
try:
154-
return (
155-
self._get_model_config(ds_model),
156-
ds_model.display_name,
157-
)
158-
except AquaFileNotFoundError:
159-
logger.warning(
160-
"config.json not found in artifact, fetching from Hugging Face Hub."
161-
)
162173

163174
return self._fetch_hf_config(model_id), model_id
164175

0 commit comments

Comments
 (0)