77import traceback
88from dataclasses import fields
99from datetime import datetime , timedelta
10- from typing import Any , Dict , Optional , Union
10+ from itertools import chain
11+ from typing import Any , Dict , List , Optional , Union
1112
1213import oci
1314from cachetools import TTLCache , cached
14- from oci .data_science .models import UpdateModelDetails , UpdateModelProvenanceDetails
15+ from oci .data_science .models import (
16+ ContainerSummary ,
17+ UpdateModelDetails ,
18+ UpdateModelProvenanceDetails ,
19+ )
1520
1621from ads import set_auth
1722from ads .aqua import logger
2429 is_valid_ocid ,
2530 load_config ,
2631)
32+ from ads .aqua .config .container_config import (
33+ AquaContainerConfig ,
34+ AquaContainerConfigItem ,
35+ )
36+ from ads .aqua .constants import SERVICE_MANAGED_CONTAINER_URI_SCHEME
2737from ads .common import oci_client as oc
2838from ads .common .auth import default_signer
2939from ads .common .utils import UNKNOWN , extract_region , is_path_exists
@@ -240,7 +250,9 @@ def create_model_catalog(
240250 .with_custom_metadata_list (model_custom_metadata )
241251 .with_defined_metadata_list (model_taxonomy_metadata )
242252 .with_provenance_metadata (ModelProvenanceMetadata (training_id = UNKNOWN ))
243- .with_defined_tags (** (defined_tags or {})) # Create defined tags when a model is created.
253+ .with_defined_tags (
254+ ** (defined_tags or {})
255+ ) # Create defined tags when a model is created.
244256 .create (
245257 ** kwargs ,
246258 )
@@ -271,6 +283,43 @@ def if_artifact_exist(self, model_id: str, **kwargs) -> bool:
271283 logger .info (f"Artifact not found in model { model_id } ." )
272284 return False
273285
286+ def get_config_from_metadata (
287+ self , model_id : str , metadata_key : str
288+ ) -> ModelConfigResult :
289+ """Gets the config for the given Aqua model from model catalog metadata content.
290+
291+ Parameters
292+ ----------
293+ model_id: str
294+ The OCID of the Aqua model.
295+ metadata_key: str
296+ The metadata key name where artifact content is stored
297+ Returns
298+ -------
299+ ModelConfigResult
300+ A Pydantic model containing the model_details (extracted from OCI) and the config dictionary.
301+ """
302+ config = {}
303+ oci_model = self .ds_client .get_model (model_id ).data
304+ try :
305+ config = self .ds_client .get_model_defined_metadatum_artifact_content (
306+ model_id , metadata_key
307+ ).data .content .decode ("utf-8" )
308+ return ModelConfigResult (config = json .loads (config ), model_details = oci_model )
309+ except UnicodeDecodeError as ex :
310+ logger .error (
311+ f"Failed to decode content for '{ metadata_key } ' in defined metadata for model '{ model_id } ' : { ex } "
312+ )
313+ except json .JSONDecodeError as ex :
314+ logger .error (
315+ f"Invalid JSON format for '{ metadata_key } ' in defined metadata for model '{ model_id } ' : { ex } "
316+ )
317+ except Exception as ex :
318+ logger .error (
319+ f"Failed to retrieve defined metadata key '{ metadata_key } ' for model '{ model_id } ': { ex } "
320+ )
321+ return ModelConfigResult (config = config , model_details = oci_model )
322+
274323 @cached (cache = TTLCache (maxsize = 1 , ttl = timedelta (minutes = 1 ), timer = datetime .now ))
275324 def get_config (
276325 self ,
@@ -310,22 +359,7 @@ def get_config(
310359 raise AquaRuntimeError (f"Target model { oci_model .id } is not an Aqua model." )
311360
312361 config : Dict [str , Any ] = {}
313-
314- # if the current model has a service model tag, then
315- if Tags .AQUA_SERVICE_MODEL_TAG in oci_model .freeform_tags :
316- base_model_ocid = oci_model .freeform_tags [Tags .AQUA_SERVICE_MODEL_TAG ]
317- logger .info (
318- f"Base model found for the model: { oci_model .id } . "
319- f"Loading { config_file_name } for base model { base_model_ocid } ."
320- )
321- if config_folder == ConfigFolder .ARTIFACT :
322- artifact_path = get_artifact_path (oci_model .custom_metadata_list )
323- else :
324- base_model = self .ds_client .get_model (base_model_ocid ).data
325- artifact_path = get_artifact_path (base_model .custom_metadata_list )
326- else :
327- logger .info (f"Loading { config_file_name } for model { oci_model .id } ..." )
328- artifact_path = get_artifact_path (oci_model .custom_metadata_list )
362+ artifact_path = get_artifact_path (oci_model .custom_metadata_list )
329363 if not artifact_path :
330364 logger .debug (
331365 f"Failed to get artifact path from custom metadata for the model: { model_id } "
@@ -340,7 +374,7 @@ def get_config(
340374 config_file_path = os .path .join (config_path , config_file_name )
341375 if is_path_exists (config_file_path ):
342376 try :
343- logger .debug (
377+ logger .info (
344378 f"Loading config: `{ config_file_name } ` from `{ config_path } `"
345379 )
346380 config = load_config (
@@ -361,6 +395,85 @@ def get_config(
361395
362396 return ModelConfigResult (config = config , model_details = oci_model )
363397
398+ def get_container_image (self , container_type : str = None ) -> str :
399+ """
400+ Gets the latest smc container complete image name from the given container type.
401+
402+ Parameters
403+ ----------
404+ container_type: str
405+ type of container, can be either odsc-vllm-serving, odsc-llm-fine-tuning, odsc-llm-evaluate
406+
407+ Returns
408+ -------
409+ str:
410+ A complete container name along with version. ex: dsmc://odsc-vllm-serving:0.7.4.1
411+ """
412+
413+ containers = self .list_service_containers ()
414+ container = next (
415+ (c for c in containers if c .is_latest and c .family_name == container_type ),
416+ None ,
417+ )
418+ if not container :
419+ raise AquaValueError (f"Invalid container type : { container_type } " )
420+ container_image = (
421+ SERVICE_MANAGED_CONTAINER_URI_SCHEME
422+ + container .container_name
423+ + ":"
424+ + container .tag
425+ )
426+ return container_image
427+
428+ @cached (cache = TTLCache (maxsize = 20 , ttl = timedelta (minutes = 30 ), timer = datetime .now ))
429+ def list_service_containers (self ) -> List [ContainerSummary ]:
430+ """
431+ List containers from containers.conf in OCI Datascience control plane
432+ """
433+ containers = self .ds_client .list_containers ().data
434+ return containers
435+
436+ def get_container_config (self ) -> AquaContainerConfig :
437+ """
438+ Fetches latest containers from containers.conf in OCI Datascience control plane
439+
440+ Returns
441+ -------
442+ AquaContainerConfig
443+ An Object that contains latest container info for the given container family
444+
445+ """
446+ return AquaContainerConfig .from_service_config (
447+ service_containers = self .list_service_containers ()
448+ )
449+
450+ def get_container_config_item (
451+ self , container_family : str
452+ ) -> AquaContainerConfigItem :
453+ """
454+ Fetches latest container for given container_family_name from containers.conf in OCI Datascience control plane
455+
456+ Returns
457+ -------
458+ AquaContainerConfigItem
459+ An Object that contains latest container info for the given container family
460+
461+ """
462+
463+ aqua_container_config = self .get_container_config ()
464+ inference_config = aqua_container_config .inference .values ()
465+ ft_config = aqua_container_config .finetune .values ()
466+ eval_config = aqua_container_config .evaluate .values ()
467+ container = next (
468+ (
469+ container
470+ for container in chain (inference_config , ft_config , eval_config )
471+ if container .family .lower () == container_family .lower ()
472+ ),
473+ None ,
474+ )
475+ return container
476+
364477 @property
365478 def telemetry (self ):
366479 if not self ._telemetry :
0 commit comments