11#!/usr/bin/env python
2- # Copyright (c) 2024 Oracle and/or its affiliates.
2+ # Copyright (c) 2024, 2025 Oracle and/or its affiliates.
33# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
44import os
55import pathlib
1515from ads .aqua import ODSC_MODEL_COMPARTMENT_OCID , logger
1616from ads .aqua .app import AquaApp
1717from ads .aqua .common .enums import (
18+ CustomInferenceContainerTypeFamily ,
1819 FineTuningContainerTypeFamily ,
1920 InferenceContainerTypeFamily ,
2021 Tags ,
2324from ads .aqua .common .utils import (
2425 LifecycleStatus ,
2526 _build_resource_identifier ,
27+ cleanup_local_hf_model_artifact ,
2628 copy_model_config ,
2729 create_word_icon ,
2830 generate_tei_cmd_var ,
@@ -376,8 +378,10 @@ def delete_model(self, model_id):
376378 f"Failed to delete model:{ model_id } . Only registered models or finetuned model can be deleted."
377379 )
378380
379- @telemetry (entry_point = "plugin=model&action=delete" , name = "aqua" )
380- def edit_registered_model (self , id , inference_container , enable_finetuning , task ):
381+ @telemetry (entry_point = "plugin=model&action=edit" , name = "aqua" )
382+ def edit_registered_model (
383+ self , id , inference_container , inference_container_uri , enable_finetuning , task
384+ ):
381385 """Edits the default config of unverified registered model.
382386
383387 Parameters
@@ -386,6 +390,8 @@ def edit_registered_model(self, id, inference_container, enable_finetuning, task
386390 The model OCID.
387391 inference_container: str.
388392 The inference container family name
393+ inference_container_uri: str
394+ The inference container uri for embedding models
389395 enable_finetuning: str
390396 Flag to enable or disable finetuning over the model. Defaults to None
391397 task:
@@ -401,19 +407,44 @@ def edit_registered_model(self, id, inference_container, enable_finetuning, task
401407 if ds_model .freeform_tags .get (Tags .BASE_MODEL_CUSTOM , None ):
402408 if ds_model .freeform_tags .get (Tags .AQUA_SERVICE_MODEL_TAG , None ):
403409 raise AquaRuntimeError (
404- f"Failed to edit model: { id } . Only registered unverified models can be edited."
410+ " Only registered unverified models can be edited."
405411 )
406412 else :
407413 custom_metadata_list = ds_model .custom_metadata_list
408414 freeform_tags = ds_model .freeform_tags
409415 if inference_container :
410- custom_metadata_list .add (
411- key = ModelCustomMetadataFields .DEPLOYMENT_CONTAINER ,
412- value = inference_container ,
413- category = MetadataCustomCategory .OTHER ,
414- description = "Deployment container mapping for SMC" ,
415- replace = True ,
416- )
416+ if (
417+ inference_container in CustomInferenceContainerTypeFamily
418+ and inference_container_uri is None
419+ ):
420+ raise AquaRuntimeError (
421+ "Inference container URI must be provided."
422+ )
423+ else :
424+ custom_metadata_list .add (
425+ key = ModelCustomMetadataFields .DEPLOYMENT_CONTAINER ,
426+ value = inference_container ,
427+ category = MetadataCustomCategory .OTHER ,
428+ description = "Deployment container mapping for SMC" ,
429+ replace = True ,
430+ )
431+ if inference_container_uri :
432+ if (
433+ inference_container in CustomInferenceContainerTypeFamily
434+ or inference_container is None
435+ ):
436+ custom_metadata_list .add (
437+ key = ModelCustomMetadataFields .DEPLOYMENT_CONTAINER_URI ,
438+ value = inference_container_uri ,
439+ category = MetadataCustomCategory .OTHER ,
440+ description = f"Inference container URI for { ds_model .display_name } " ,
441+ replace = True ,
442+ )
443+ else :
444+ raise AquaRuntimeError (
445+ f"Inference container URI can be edited only with container values: { CustomInferenceContainerTypeFamily .values ()} "
446+ )
447+
417448 if enable_finetuning is not None :
418449 if enable_finetuning .lower () == "true" :
419450 custom_metadata_list .add (
@@ -448,9 +479,7 @@ def edit_registered_model(self, id, inference_container, enable_finetuning, task
448479 )
449480 AquaApp ().update_model (id , update_model_details )
450481 else :
451- raise AquaRuntimeError (
452- f"Failed to edit model:{ id } . Only registered unverified models can be edited."
453- )
482+ raise AquaRuntimeError ("Only registered unverified models can be edited." )
454483
455484 def _fetch_metric_from_metadata (
456485 self ,
@@ -869,8 +898,7 @@ def _create_model_catalog_entry(
869898 # only add cmd vars if inference container is not an SMC
870899 if (
871900 inference_container not in smc_container_set
872- and inference_container
873- == InferenceContainerTypeFamily .AQUA_TEI_CONTAINER_FAMILY
901+ and inference_container in CustomInferenceContainerTypeFamily .values ()
874902 ):
875903 cmd_vars = generate_tei_cmd_var (os_path )
876904 metadata .add (
@@ -1322,20 +1350,20 @@ def _download_model_from_hf(
13221350 Returns
13231351 -------
13241352 model_artifact_path (str): Location where the model artifacts are downloaded.
1325-
13261353 """
13271354 # Download the model from hub
1328- if not local_dir :
1329- local_dir = os .path .join (os .path .expanduser ("~" ), "cached-model" )
1330- local_dir = os .path .join (local_dir , model_name )
1331- os .makedirs (local_dir , exist_ok = True )
1332- snapshot_download (
1355+ if local_dir :
1356+ local_dir = os .path .join (local_dir , model_name )
1357+ os .makedirs (local_dir , exist_ok = True )
1358+
1359+ # if local_dir is not set, the return value points to the cached data folder
1360+ local_dir = snapshot_download (
13331361 repo_id = model_name ,
13341362 local_dir = local_dir ,
13351363 allow_patterns = allow_patterns ,
13361364 ignore_patterns = ignore_patterns ,
13371365 )
1338- # Upload to object storage and skip .cache/huggingface/ folder
1366+ # Upload to object storage
13391367 model_artifact_path = upload_folder (
13401368 os_path = os_path ,
13411369 local_dir = local_dir ,
@@ -1365,6 +1393,8 @@ def register(
13651393 ignore_patterns (list): Model files matching any of the patterns are not downloaded.
13661394 Example: ["*.json"] will ignore all .json files. ["folder/*"] will ignore all files under `folder`.
13671395 Patterns are Standard Wildcards (globbing patterns) and rules can be found here: https://docs.python.org/3/library/fnmatch.html
1396+ cleanup_model_cache (bool): Deletes downloaded files from local machine after model is successfully
1397+ registered. Set to True by default.
13681398
13691399 Returns:
13701400 AquaModel:
@@ -1474,6 +1504,14 @@ def register(
14741504 detail = validation_result .telemetry_model_name ,
14751505 )
14761506
1507+ if (
1508+ import_model_details .download_from_hf
1509+ and import_model_details .cleanup_model_cache
1510+ ):
1511+ cleanup_local_hf_model_artifact (
1512+ model_name = model_name , local_dir = import_model_details .local_dir
1513+ )
1514+
14771515 return AquaModel (** aqua_model_attributes )
14781516
14791517 def _if_show (self , model : DataScienceModel ) -> bool :
0 commit comments