|
10 | 10 | import oci |
11 | 11 | from cachetools import TTLCache |
12 | 12 | from huggingface_hub import snapshot_download |
13 | | -from oci.data_science.models import JobRun, Model |
| 13 | +from oci.data_science.models import JobRun, Model, UpdateModelDetails, Metadata |
14 | 14 |
|
15 | 15 | from ads.aqua import ODSC_MODEL_COMPARTMENT_OCID, logger |
16 | 16 | from ads.aqua.app import AquaApp |
17 | | -from ads.aqua.common.enums import InferenceContainerTypeFamily, Tags |
| 17 | +from ads.aqua.common.enums import InferenceContainerTypeFamily, Tags, FineTuningContainerTypeFamily |
18 | 18 | from ads.aqua.common.errors import AquaRuntimeError, AquaValueError |
19 | 19 | from ads.aqua.common.utils import ( |
20 | 20 | LifecycleStatus, |
|
75 | 75 | TENANCY_OCID, |
76 | 76 | ) |
77 | 77 | from ads.model import DataScienceModel |
78 | | -from ads.model.model_metadata import ModelCustomMetadata, ModelCustomMetadataItem |
| 78 | +from ads.model.model_metadata import ModelCustomMetadata, ModelCustomMetadataItem, MetadataCustomCategory |
79 | 79 | from ads.telemetry import telemetry |
80 | 80 |
|
81 | 81 |
|
@@ -332,6 +332,72 @@ def delete_registered_model(self,model_id): |
332 | 332 | else: |
333 | 333 | raise AquaRuntimeError(f"Failed to delete model:{model_id}. Only registered models can be deleted.") |
334 | 334 |
|
| 335 | + @telemetry(entry_point="plugin=model&action=delete", name="aqua") |
| 336 | + def edit_registered_model(self,id,inference_container,enable_finetuning,task): |
| 337 | + """Edits the default config of unverified registered model. |
| 338 | +
|
| 339 | + Parameters |
| 340 | + ---------- |
| 341 | + id: str |
| 342 | + The model OCID. |
| 343 | + inference_container: str. |
| 344 | + The inference container family name |
| 345 | + enable_finetuning: str |
| 346 | + Flag to enable or disable finetuning over the model. Defaults to None |
| 347 | +
|
| 348 | + Returns |
| 349 | + ------- |
| 350 | + Model: |
| 351 | + The instance of oci.data_science.models.Model. |
| 352 | +
|
| 353 | + """ |
| 354 | + ds_model=DataScienceModel.from_id(id) |
| 355 | + if ds_model.freeform_tags.get(Tags.BASE_MODEL_CUSTOM,None): |
| 356 | + if ds_model.freeform_tags.get(Tags.AQUA_SERVICE_MODEL_TAG,None): |
| 357 | + raise AquaRuntimeError(f"Failed to edit model:{id}. Only registered unverified models can be edited.") |
| 358 | + else: |
| 359 | + custom_metadata_list=ds_model.custom_metadata_list |
| 360 | + freeform_tags=ds_model.freeform_tags |
| 361 | + if inference_container: |
| 362 | + custom_metadata_list.add(key=ModelCustomMetadataFields.DEPLOYMENT_CONTAINER, |
| 363 | + value=inference_container, |
| 364 | + category=MetadataCustomCategory.OTHER, |
| 365 | + description="Deployment container mapping for SMC", |
| 366 | + replace=True |
| 367 | + ) |
| 368 | + if enable_finetuning is not None: |
| 369 | + if enable_finetuning.lower()=="true": |
| 370 | + custom_metadata_list.add(key=ModelCustomMetadataFields.FINETUNE_CONTAINER, |
| 371 | + value=FineTuningContainerTypeFamily.AQUA_FINETUNING_CONTAINER_FAMILY, |
| 372 | + category=MetadataCustomCategory.OTHER, |
| 373 | + description="Fine-tuning container mapping for SMC", |
| 374 | + replace=True |
| 375 | + ) |
| 376 | + freeform_tags.update({Tags.READY_TO_FINE_TUNE:"true"}) |
| 377 | + elif enable_finetuning.lower()=="false": |
| 378 | + try: |
| 379 | + custom_metadata_list.remove(ModelCustomMetadataFields.FINETUNE_CONTAINER) |
| 380 | + freeform_tags.pop(Tags.READY_TO_FINE_TUNE) |
| 381 | + except Exception as ex: |
| 382 | + raise AquaRuntimeError(f"The given model already doesn't support finetuning: {ex}") |
| 383 | + |
| 384 | + custom_metadata_list.remove("modelDescription") |
| 385 | + if task: |
| 386 | + freeform_tags.update({"task":task}) |
| 387 | + |
| 388 | + updated_custom_metadata_list = [ |
| 389 | + Metadata(**metadata) |
| 390 | + for metadata in custom_metadata_list.to_dict()["data"] |
| 391 | + ] |
| 392 | + update_model_details = UpdateModelDetails( |
| 393 | + custom_metadata_list=updated_custom_metadata_list, |
| 394 | + freeform_tags=freeform_tags |
| 395 | + ) |
| 396 | + return self.ds_client.update_model(id,update_model_details).data |
| 397 | + else: |
| 398 | + raise AquaRuntimeError(f"Failed to edit model:{id}. Only registered unverified models can be deleted.") |
| 399 | + |
| 400 | + |
335 | 401 | def _fetch_metric_from_metadata( |
336 | 402 | self, |
337 | 403 | custom_metadata_list: ModelCustomMetadata, |
|
0 commit comments