55import json
66import os
77import traceback
8+ from concurrent .futures import ThreadPoolExecutor
89from dataclasses import fields
910from datetime import datetime , timedelta
1011from itertools import chain
2223from ads .aqua import logger
2324from ads .aqua .common .entities import ModelConfigResult
2425from ads .aqua .common .enums import ConfigFolder , Tags
25- from ads .aqua .common .errors import AquaRuntimeError , AquaValueError
26+ from ads .aqua .common .errors import AquaValueError
2627from ads .aqua .common .utils import (
2728 _is_valid_mvs ,
2829 get_artifact_path ,
5859class AquaApp :
5960 """Base Aqua App to contain common components."""
6061
62+ MAX_WORKERS = 10 # Number of workers for asynchronous resource loading
63+
6164 @telemetry (name = "aqua" )
6265 def __init__ (self ) -> None :
6366 if OCI_RESOURCE_PRINCIPAL_VERSION :
6467 set_auth ("resource_principal" )
6568 self ._auth = default_signer ({"service_endpoint" : OCI_ODSC_SERVICE_ENDPOINT })
6669 self .ds_client = oc .OCIClientFactory (** self ._auth ).data_science
70+ self .compute_client = oc .OCIClientFactory (** default_signer ()).compute
6771 self .logging_client = oc .OCIClientFactory (** default_signer ()).logging_management
6872 self .identity_client = oc .OCIClientFactory (** default_signer ()).identity
6973 self .region = extract_region (self ._auth )
@@ -127,20 +131,69 @@ def update_model_provenance(
127131 update_model_provenance_details = update_model_provenance_details ,
128132 )
129133
130- # TODO: refactor model evaluation implementation to use it.
131134 @staticmethod
132135 def get_source (source_id : str ) -> Union [ModelDeployment , DataScienceModel ]:
133- if is_valid_ocid (source_id ):
134- if "datasciencemodeldeployment" in source_id :
135- return ModelDeployment .from_id (source_id )
136- elif "datasciencemodel" in source_id :
137- return DataScienceModel .from_id (source_id )
136+ """
137+ Fetches a model or model deployment based on the provided OCID.
138+
139+ Parameters
140+ ----------
141+ source_id : str
142+ OCID of the Data Science model or model deployment.
143+
144+ Returns
145+ -------
146+ Union[ModelDeployment, DataScienceModel]
147+ The corresponding resource object.
138148
149+ Raises
150+ ------
151+ AquaValueError
152+ If the OCID is invalid or unsupported.
153+ """
154+ logger .debug (f"Resolving source for ID: { source_id } " )
155+ if not is_valid_ocid (source_id ):
156+ logger .error (f"Invalid OCID format: { source_id } " )
157+ raise AquaValueError (
158+ f"Invalid source ID: { source_id } . Please provide a valid model or model deployment OCID."
159+ )
160+
161+ if "datasciencemodeldeployment" in source_id :
162+ logger .debug (f"Identified as ModelDeployment OCID: { source_id } " )
163+ return ModelDeployment .from_id (source_id )
164+
165+ if "datasciencemodel" in source_id :
166+ logger .debug (f"Identified as DataScienceModel OCID: { source_id } " )
167+ return DataScienceModel .from_id (source_id )
168+
169+ logger .error (f"Unrecognized OCID type: { source_id } " )
139170 raise AquaValueError (
140- f"Invalid source { source_id } . "
141- "Specify either a model or model deployment id."
171+ f"Unsupported source ID type: { source_id } . Must be a model or model deployment OCID."
142172 )
143173
174+ def get_multi_source (
175+ self ,
176+ ids : List [str ],
177+ ) -> Dict [str , Union [ModelDeployment , DataScienceModel ]]:
178+ """
179+ Retrieves multiple DataScience resources concurrently.
180+
181+ Parameters
182+ ----------
183+ ids : List[str]
184+ A list of DataScience OCIDs.
185+
186+ Returns
187+ -------
188+ Dict[str, Union[ModelDeployment, DataScienceModel]]
189+ A mapping from OCID to the corresponding resolved resource object.
190+ """
191+ logger .debug (f"Fetching { ids } sources in parallel." )
192+ with ThreadPoolExecutor (max_workers = self .MAX_WORKERS ) as executor :
193+ results = list (executor .map (self .get_source , ids ))
194+
195+ return dict (zip (ids , results ))
196+
144197 # TODO: refactor model evaluation implementation to use it.
145198 @staticmethod
146199 def create_model_version_set (
@@ -283,8 +336,11 @@ def if_artifact_exist(self, model_id: str, **kwargs) -> bool:
283336 logger .info (f"Artifact not found in model { model_id } ." )
284337 return False
285338
339+ @cached (cache = TTLCache (maxsize = 5 , ttl = timedelta (minutes = 1 ), timer = datetime .now ))
286340 def get_config_from_metadata (
287- self , model_id : str , metadata_key : str
341+ self ,
342+ model_id : str ,
343+ metadata_key : str ,
288344 ) -> ModelConfigResult :
289345 """Gets the config for the given Aqua model from model catalog metadata content.
290346
@@ -299,8 +355,9 @@ def get_config_from_metadata(
299355 ModelConfigResult
300356 A Pydantic model containing the model_details (extracted from OCI) and the config dictionary.
301357 """
302- config = {}
358+ config : Dict [ str , Any ] = {}
303359 oci_model = self .ds_client .get_model (model_id ).data
360+
304361 try :
305362 config = self .ds_client .get_model_defined_metadatum_artifact_content (
306363 model_id , metadata_key
@@ -320,7 +377,7 @@ def get_config_from_metadata(
320377 )
321378 return ModelConfigResult (config = config , model_details = oci_model )
322379
323- @cached (cache = TTLCache (maxsize = 1 , ttl = timedelta (minutes = 1 ), timer = datetime .now ))
380+ @cached (cache = TTLCache (maxsize = 1 , ttl = timedelta (minutes = 5 ), timer = datetime .now ))
324381 def get_config (
325382 self ,
326383 model_id : str ,
@@ -345,8 +402,10 @@ def get_config(
345402 ModelConfigResult
346403 A Pydantic model containing the model_details (extracted from OCI) and the config dictionary.
347404 """
348- config_folder = config_folder or ConfigFolder . CONFIG
405+ config : Dict [ str , Any ] = {}
349406 oci_model = self .ds_client .get_model (model_id ).data
407+
408+ config_folder = config_folder or ConfigFolder .CONFIG
350409 oci_aqua = (
351410 (
352411 Tags .AQUA_TAG in oci_model .freeform_tags
@@ -356,9 +415,9 @@ def get_config(
356415 else False
357416 )
358417 if not oci_aqua :
359- raise AquaRuntimeError (f"Target model { oci_model .id } is not an Aqua model." )
418+ logger .debug (f"Target model { oci_model .id } is not an Aqua model." )
419+ return ModelConfigResult (config = config , model_details = oci_model )
360420
361- config : Dict [str , Any ] = {}
362421 artifact_path = get_artifact_path (oci_model .custom_metadata_list )
363422 if not artifact_path :
364423 logger .debug (
0 commit comments