55import json
66import os
77import traceback
8+ from concurrent .futures import ThreadPoolExecutor
89from dataclasses import fields
910from datetime import datetime , timedelta
1011from itertools import chain
2223from ads .aqua import logger
2324from ads .aqua .common .entities import ModelConfigResult
2425from ads .aqua .common .enums import ConfigFolder , Tags
25- from ads .aqua .common .errors import AquaRuntimeError , AquaValueError
26+ from ads .aqua .common .errors import AquaValueError
2627from ads .aqua .common .utils import (
2728 _is_valid_mvs ,
2829 get_artifact_path ,
5859class AquaApp :
5960 """Base Aqua App to contain common components."""
6061
62+ MAX_WORKERS = 10 # Number of workers for asynchronous resource loading
63+
6164 @telemetry (name = "aqua" )
6265 def __init__ (self ) -> None :
6366 if OCI_RESOURCE_PRINCIPAL_VERSION :
@@ -128,20 +131,69 @@ def update_model_provenance(
128131 update_model_provenance_details = update_model_provenance_details ,
129132 )
130133
131- # TODO: refactor model evaluation implementation to use it.
132134 @staticmethod
133135 def get_source (source_id : str ) -> Union [ModelDeployment , DataScienceModel ]:
134- if is_valid_ocid (source_id ):
135- if "datasciencemodeldeployment" in source_id :
136- return ModelDeployment .from_id (source_id )
137- elif "datasciencemodel" in source_id :
138- return DataScienceModel .from_id (source_id )
136+ """
137+ Fetches a model or model deployment based on the provided OCID.
138+
139+ Parameters
140+ ----------
141+ source_id : str
142+ OCID of the Data Science model or model deployment.
143+
144+ Returns
145+ -------
146+ Union[ModelDeployment, DataScienceModel]
147+ The corresponding resource object.
139148
149+ Raises
150+ ------
151+ AquaValueError
152+ If the OCID is invalid or unsupported.
153+ """
154+ logger .debug (f"Resolving source for ID: { source_id } " )
155+ if not is_valid_ocid (source_id ):
156+ logger .error (f"Invalid OCID format: { source_id } " )
157+ raise AquaValueError (
158+ f"Invalid source ID: { source_id } . Please provide a valid model or model deployment OCID."
159+ )
160+
161+ if "datasciencemodeldeployment" in source_id :
162+ logger .debug (f"Identified as ModelDeployment OCID: { source_id } " )
163+ return ModelDeployment .from_id (source_id )
164+
165+ if "datasciencemodel" in source_id :
166+ logger .debug (f"Identified as DataScienceModel OCID: { source_id } " )
167+ return DataScienceModel .from_id (source_id )
168+
169+ logger .error (f"Unrecognized OCID type: { source_id } " )
140170 raise AquaValueError (
141- f"Invalid source { source_id } . "
142- "Specify either a model or model deployment id."
171+ f"Unsupported source ID type: { source_id } . Must be a model or model deployment OCID."
143172 )
144173
174+ def get_multi_source (
175+ self ,
176+ ids : List [str ],
177+ ) -> Dict [str , Union [ModelDeployment , DataScienceModel ]]:
178+ """
179+ Retrieves multiple DataScience resources concurrently.
180+
181+ Parameters
182+ ----------
183+ ids : List[str]
184+ A list of DataScience OCIDs.
185+
186+ Returns
187+ -------
188+ Dict[str, Union[ModelDeployment, DataScienceModel]]
189+ A mapping from OCID to the corresponding resolved resource object.
190+ """
191+ logger .debug (f"Fetching { ids } sources in parallel." )
192+ with ThreadPoolExecutor (max_workers = self .MAX_WORKERS ) as executor :
193+ results = list (executor .map (self .get_source , ids ))
194+
195+ return dict (zip (ids , results ))
196+
145197 # TODO: refactor model evaluation implementation to use it.
146198 @staticmethod
147199 def create_model_version_set (
@@ -284,8 +336,11 @@ def if_artifact_exist(self, model_id: str, **kwargs) -> bool:
284336 logger .info (f"Artifact not found in model { model_id } ." )
285337 return False
286338
339+ @cached (cache = TTLCache (maxsize = 5 , ttl = timedelta (minutes = 1 ), timer = datetime .now ))
287340 def get_config_from_metadata (
288- self , model_id : str , metadata_key : str
341+ self ,
342+ model_id : str ,
343+ metadata_key : str ,
289344 ) -> ModelConfigResult :
290345 """Gets the config for the given Aqua model from model catalog metadata content.
291346
@@ -300,8 +355,9 @@ def get_config_from_metadata(
300355 ModelConfigResult
301356 A Pydantic model containing the model_details (extracted from OCI) and the config dictionary.
302357 """
303- config = {}
358+ config : Dict [ str , Any ] = {}
304359 oci_model = self .ds_client .get_model (model_id ).data
360+
305361 try :
306362 config = self .ds_client .get_model_defined_metadatum_artifact_content (
307363 model_id , metadata_key
@@ -321,7 +377,7 @@ def get_config_from_metadata(
321377 )
322378 return ModelConfigResult (config = config , model_details = oci_model )
323379
324- @cached (cache = TTLCache (maxsize = 1 , ttl = timedelta (minutes = 1 ), timer = datetime .now ))
380+ @cached (cache = TTLCache (maxsize = 1 , ttl = timedelta (minutes = 5 ), timer = datetime .now ))
325381 def get_config (
326382 self ,
327383 model_id : str ,
@@ -346,8 +402,10 @@ def get_config(
346402 ModelConfigResult
347403 A Pydantic model containing the model_details (extracted from OCI) and the config dictionary.
348404 """
349- config_folder = config_folder or ConfigFolder . CONFIG
405+ config : Dict [ str , Any ] = {}
350406 oci_model = self .ds_client .get_model (model_id ).data
407+
408+ config_folder = config_folder or ConfigFolder .CONFIG
351409 oci_aqua = (
352410 (
353411 Tags .AQUA_TAG in oci_model .freeform_tags
@@ -357,9 +415,9 @@ def get_config(
357415 else False
358416 )
359417 if not oci_aqua :
360- raise AquaRuntimeError (f"Target model { oci_model .id } is not an Aqua model." )
418+ logger .debug (f"Target model { oci_model .id } is not an Aqua model." )
419+ return ModelConfigResult (config = config , model_details = oci_model )
361420
362- config : Dict [str , Any ] = {}
363421 artifact_path = get_artifact_path (oci_model .custom_metadata_list )
364422 if not artifact_path :
365423 logger .debug (
0 commit comments