1515from ads .aqua import ODSC_MODEL_COMPARTMENT_OCID , logger
1616from ads .aqua .app import AquaApp
1717from ads .aqua .common .enums import (
18+ CustomInferenceContainerTypeFamily ,
1819 FineTuningContainerTypeFamily ,
1920 InferenceContainerTypeFamily ,
2021 Tags ,
@@ -377,8 +378,10 @@ def delete_model(self, model_id):
377378 f"Failed to delete model:{ model_id } . Only registered models or finetuned model can be deleted."
378379 )
379380
380- @telemetry (entry_point = "plugin=model&action=delete" , name = "aqua" )
381- def edit_registered_model (self , id , inference_container , enable_finetuning , task ):
381+ @telemetry (entry_point = "plugin=model&action=edit" , name = "aqua" )
382+ def edit_registered_model (
383+ self , id , inference_container , inference_container_uri , enable_finetuning , task
384+ ):
382385 """Edits the default config of unverified registered model.
383386
384387 Parameters
@@ -387,6 +390,8 @@ def edit_registered_model(self, id, inference_container, enable_finetuning, task
387390 The model OCID.
388391 inference_container: str.
389392 The inference container family name
393+ inference_container_uri: str
394+ The inference container uri for embedding models
390395 enable_finetuning: str
391396 Flag to enable or disable finetuning over the model. Defaults to None
392397 task:
@@ -402,19 +407,44 @@ def edit_registered_model(self, id, inference_container, enable_finetuning, task
402407 if ds_model .freeform_tags .get (Tags .BASE_MODEL_CUSTOM , None ):
403408 if ds_model .freeform_tags .get (Tags .AQUA_SERVICE_MODEL_TAG , None ):
404409 raise AquaRuntimeError (
405- f"Failed to edit model: { id } . Only registered unverified models can be edited."
410+ " Only registered unverified models can be edited."
406411 )
407412 else :
408413 custom_metadata_list = ds_model .custom_metadata_list
409414 freeform_tags = ds_model .freeform_tags
410415 if inference_container :
411- custom_metadata_list .add (
412- key = ModelCustomMetadataFields .DEPLOYMENT_CONTAINER ,
413- value = inference_container ,
414- category = MetadataCustomCategory .OTHER ,
415- description = "Deployment container mapping for SMC" ,
416- replace = True ,
417- )
416+ if (
417+ inference_container in CustomInferenceContainerTypeFamily
418+ and inference_container_uri is None
419+ ):
420+ raise AquaRuntimeError (
421+ "Inference container URI must be provided."
422+ )
423+ else :
424+ custom_metadata_list .add (
425+ key = ModelCustomMetadataFields .DEPLOYMENT_CONTAINER ,
426+ value = inference_container ,
427+ category = MetadataCustomCategory .OTHER ,
428+ description = "Deployment container mapping for SMC" ,
429+ replace = True ,
430+ )
431+ if inference_container_uri :
432+ if (
433+ inference_container in CustomInferenceContainerTypeFamily
434+ or inference_container is None
435+ ):
436+ custom_metadata_list .add (
437+ key = ModelCustomMetadataFields .DEPLOYMENT_CONTAINER_URI ,
438+ value = inference_container_uri ,
439+ category = MetadataCustomCategory .OTHER ,
440+ description = f"Inference container URI for { ds_model .display_name } " ,
441+ replace = True ,
442+ )
443+ else :
444+ raise AquaRuntimeError (
445+ f"Inference container URI can be edited only with container values: { CustomInferenceContainerTypeFamily .values ()} "
446+ )
447+
418448 if enable_finetuning is not None :
419449 if enable_finetuning .lower () == "true" :
420450 custom_metadata_list .add (
@@ -449,9 +479,7 @@ def edit_registered_model(self, id, inference_container, enable_finetuning, task
449479 )
450480 AquaApp ().update_model (id , update_model_details )
451481 else :
452- raise AquaRuntimeError (
453- f"Failed to edit model:{ id } . Only registered unverified models can be edited."
454- )
482+ raise AquaRuntimeError ("Only registered unverified models can be edited." )
455483
456484 def _fetch_metric_from_metadata (
457485 self ,
@@ -870,8 +898,7 @@ def _create_model_catalog_entry(
870898 # only add cmd vars if inference container is not an SMC
871899 if (
872900 inference_container not in smc_container_set
873- and inference_container
874- == InferenceContainerTypeFamily .AQUA_TEI_CONTAINER_FAMILY
901+ and inference_container in CustomInferenceContainerTypeFamily .values ()
875902 ):
876903 cmd_vars = generate_tei_cmd_var (os_path )
877904 metadata .add (
@@ -1328,7 +1355,9 @@ def _download_model_from_hf(
13281355 if local_dir :
13291356 local_dir = os .path .join (local_dir , model_name )
13301357 os .makedirs (local_dir , exist_ok = True )
1331- snapshot_download (
1358+
1359+ # if local_dir is not set, the return value points to the cached data folder
1360+ local_dir = snapshot_download (
13321361 repo_id = model_name ,
13331362 local_dir = local_dir ,
13341363 allow_patterns = allow_patterns ,
@@ -1364,7 +1393,7 @@ def register(
13641393 ignore_patterns (list): Model files matching any of the patterns are not downloaded.
13651394 Example: ["*.json"] will ignore all .json files. ["folder/*"] will ignore all files under `folder`.
13661395 Patterns are Standard Wildcards (globbing patterns) and rules can be found here: https://docs.python.org/3/library/fnmatch.html
1367- delete_from_local (bool): Deletes downloaded files from local machine after model is successfully
1396+ cleanup_model_cache (bool): Deletes downloaded files from local machine after model is successfully
13681397 registered. Set to True by default.
13691398
13701399 Returns:
@@ -1477,7 +1506,7 @@ def register(
14771506
14781507 if (
14791508 import_model_details .download_from_hf
1480- and import_model_details .delete_from_local
1509+ and import_model_details .cleanup_model_cache
14811510 ):
14821511 cleanup_local_hf_model_artifact (
14831512 model_name = model_name , local_dir = import_model_details .local_dir
0 commit comments