66import os
77import traceback
88from dataclasses import fields
9- from typing import Dict , Union
9+ from typing import Dict , Optional , Union
1010
1111import oci
1212from oci .data_science .models import UpdateModelDetails , UpdateModelProvenanceDetails
1313
1414from ads import set_auth
1515from ads .aqua import logger
16- from ads .aqua .common .enums import Tags
16+ from ads .aqua .common .enums import ConfigFolder , Tags
1717from ads .aqua .common .errors import AquaRuntimeError , AquaValueError
1818from ads .aqua .common .utils import (
1919 _is_valid_mvs ,
@@ -268,7 +268,12 @@ def if_artifact_exist(self, model_id: str, **kwargs) -> bool:
268268 logger .info (f"Artifact not found in model { model_id } ." )
269269 return False
270270
271- def get_config (self , model_id : str , config_file_name : str ) -> Dict :
271+ def get_config (
272+ self ,
273+ model_id : str ,
274+ config_file_name : str ,
275+ config_folder : Optional [str ] = ConfigFolder .CONFIG ,
276+ ) -> Dict :
272277 """Gets the config for the given Aqua model.
273278
274279 Parameters
@@ -277,12 +282,17 @@ def get_config(self, model_id: str, config_file_name: str) -> Dict:
277282 The OCID of the Aqua model.
278283 config_file_name: str
279284 name of the config file
285+ config_folder: (str, optional):
286+ subfolder path where config_file_name needs to be searched
287+ Defaults to `ConfigFolder.CONFIG`.
288+ When searching inside model artifact directory , the value is ConfigFolder.ARTIFACT`
280289
281290 Returns
282291 -------
283292 Dict:
284293 A dict of allowed configs.
285294 """
295+ config_folder = config_folder or ConfigFolder .CONFIG
286296 oci_model = self .ds_client .get_model (model_id ).data
287297 oci_aqua = (
288298 (
@@ -304,22 +314,25 @@ def get_config(self, model_id: str, config_file_name: str) -> Dict:
304314 f"Base model found for the model: { oci_model .id } . "
305315 f"Loading { config_file_name } for base model { base_model_ocid } ."
306316 )
307- base_model = self .ds_client .get_model (base_model_ocid ).data
308- artifact_path = get_artifact_path (base_model .custom_metadata_list )
317+ if config_folder == ConfigFolder .ARTIFACT :
318+ artifact_path = get_artifact_path (oci_model .custom_metadata_list )
319+ else :
320+ base_model = self .ds_client .get_model (base_model_ocid ).data
321+ artifact_path = get_artifact_path (base_model .custom_metadata_list )
309322 else :
310323 logger .info (f"Loading { config_file_name } for model { oci_model .id } ..." )
311324 artifact_path = get_artifact_path (oci_model .custom_metadata_list )
312-
313325 if not artifact_path :
314326 logger .debug (
315327 f"Failed to get artifact path from custom metadata for the model: { model_id } "
316328 )
317329 return config
318330
319- config_path = f" { os .path .dirname (artifact_path )} /config/"
331+ config_path = os .path .join ( os . path . dirname (artifact_path ), config_folder )
320332 if not is_path_exists (config_path ):
321- config_path = f"{ artifact_path .rstrip ('/' )} /config/"
322-
333+ config_path = os .path .join (artifact_path .rstrip ("/" ), config_folder )
334+ if not is_path_exists (config_path ):
335+ config_path = f"{ artifact_path .rstrip ('/' )} /"
323336 config_file_path = f"{ config_path } { config_file_name } "
324337 if is_path_exists (config_file_path ):
325338 try :
0 commit comments