2020from ads .common import utils
2121from ads .common .extended_enum import ExtendedEnumMeta
2222from ads .common .object_storage_details import ObjectStorageDetails
23- from ads .config import (
24- AQUA_SERVICE_MODELS_BUCKET as SERVICE_MODELS_BUCKET ,
25- )
2623from ads .config import (
2724 COMPARTMENT_OCID ,
2825 PROJECT_OCID ,
26+ AQUA_SERVICE_MODELS_BUCKET as SERVICE_MODELS_BUCKET ,
2927)
3028from ads .feature_engineering .schema import Schema
3129from ads .jobs .builders .base import Builder
4846
4947logger = logging .getLogger (__name__ )
5048
49+
5150_MAX_ARTIFACT_SIZE_IN_BYTES = 2147483648 # 2GB
5251MODEL_BY_REFERENCE_VERSION = "1.0"
5352MODEL_BY_REFERENCE_JSON_FILE_NAME = "model_description.json"
@@ -65,8 +64,8 @@ def __init__(self, max_artifact_size: str):
6564
6665class BucketNotVersionedError (Exception ): # pragma: no cover
6766 def __init__ (
68- self ,
69- msg = "Model artifact bucket is not versioned. Enable versioning on the bucket to proceed with model creation by reference." ,
67+ self ,
68+ msg = "Model artifact bucket is not versioned. Enable versioning on the bucket to proceed with model creation by reference." ,
7069 ):
7170 super ().__init__ (msg )
7271
@@ -527,6 +526,7 @@ class DataScienceModel(Builder):
527526 Sets path details for models created by reference. Input can be either a dict, string or json file and
528527 the schema is dictated by model_file_description_schema.json
529528
529+
530530 Examples
531531 --------
532532 >>> ds_model = (DataScienceModel()
@@ -796,7 +796,7 @@ def defined_tags(self) -> Dict[str, Dict[str, object]]:
796796 return self .get_spec (self .CONST_DEFINED_TAG )
797797
798798 def with_defined_tags (
799- self , ** kwargs : Dict [str , Dict [str , object ]]
799+ self , ** kwargs : Dict [str , Dict [str , object ]]
800800 ) -> "DataScienceModel" :
801801 """Sets defined tags.
802802
@@ -877,7 +877,7 @@ def defined_metadata_list(self) -> ModelTaxonomyMetadata:
877877 return self .get_spec (self .CONST_DEFINED_METADATA )
878878
879879 def with_defined_metadata_list (
880- self , metadata : Union [ModelTaxonomyMetadata , Dict ]
880+ self , metadata : Union [ModelTaxonomyMetadata , Dict ]
881881 ) -> "DataScienceModel" :
882882 """Sets model taxonomy (defined) metadata.
883883
@@ -901,7 +901,7 @@ def custom_metadata_list(self) -> ModelCustomMetadata:
901901 return self .get_spec (self .CONST_CUSTOM_METADATA )
902902
903903 def with_custom_metadata_list (
904- self , metadata : Union [ModelCustomMetadata , Dict ]
904+ self , metadata : Union [ModelCustomMetadata , Dict ]
905905 ) -> "DataScienceModel" :
906906 """Sets model custom metadata.
907907
@@ -925,7 +925,7 @@ def provenance_metadata(self) -> ModelProvenanceMetadata:
925925 return self .get_spec (self .CONST_PROVENANCE_METADATA )
926926
927927 def with_provenance_metadata (
928- self , metadata : Union [ModelProvenanceMetadata , Dict ]
928+ self , metadata : Union [ModelProvenanceMetadata , Dict ]
929929 ) -> "DataScienceModel" :
930930 """Sets model provenance metadata.
931931
@@ -1018,7 +1018,7 @@ def model_file_description(self) -> dict:
10181018 return self .get_spec (self .CONST_MODEL_FILE_DESCRIPTION )
10191019
10201020 def with_model_file_description (
1021- self , json_dict : dict = None , json_string : str = None , json_uri : str = None
1021+ self , json_dict : dict = None , json_string : str = None , json_uri : str = None
10221022 ):
10231023 """Sets the json file description for model passed by reference
10241024 Parameters
@@ -1041,7 +1041,7 @@ def with_model_file_description(
10411041 elif json_string :
10421042 json_data = json .loads (json_string )
10431043 elif json_uri :
1044- with open (json_uri ) as json_file :
1044+ with open (json_uri , "r" ) as json_file :
10451045 json_data = json .load (json_file )
10461046 else :
10471047 raise ValueError ("Must provide either a valid json string or URI location." )
@@ -1256,15 +1256,15 @@ def create(self, **kwargs) -> "DataScienceModel":
12561256 return self
12571257
12581258 def upload_artifact (
1259- self ,
1260- bucket_uri : Optional [str ] = None ,
1261- auth : Optional [Dict ] = None ,
1262- region : Optional [str ] = None ,
1263- overwrite_existing_artifact : Optional [bool ] = True ,
1264- remove_existing_artifact : Optional [bool ] = True ,
1265- timeout : Optional [int ] = None ,
1266- parallel_process_count : int = utils .DEFAULT_PARALLEL_PROCESS_COUNT ,
1267- model_by_reference : Optional [bool ] = False ,
1259+ self ,
1260+ bucket_uri : Optional [str ] = None ,
1261+ auth : Optional [Dict ] = None ,
1262+ region : Optional [str ] = None ,
1263+ overwrite_existing_artifact : Optional [bool ] = True ,
1264+ remove_existing_artifact : Optional [bool ] = True ,
1265+ timeout : Optional [int ] = None ,
1266+ parallel_process_count : int = utils .DEFAULT_PARALLEL_PROCESS_COUNT ,
1267+ model_by_reference : Optional [bool ] = False ,
12681268 ) -> None :
12691269 """Uploads model artifacts to the model catalog.
12701270
@@ -1334,7 +1334,7 @@ def upload_artifact(
13341334 bucket_uri = self .artifact
13351335
13361336 if not model_by_reference and (
1337- bucket_uri or utils .folder_size (self .artifact ) > _MAX_ARTIFACT_SIZE_IN_BYTES
1337+ bucket_uri or utils .folder_size (self .artifact ) > _MAX_ARTIFACT_SIZE_IN_BYTES
13381338 ):
13391339 if not bucket_uri :
13401340 raise ModelArtifactSizeError (
@@ -1405,15 +1405,15 @@ def restore_model(
14051405 )
14061406
14071407 def download_artifact (
1408- self ,
1409- target_dir : str ,
1410- auth : Optional [Dict ] = None ,
1411- force_overwrite : Optional [bool ] = False ,
1412- bucket_uri : Optional [str ] = None ,
1413- region : Optional [str ] = None ,
1414- overwrite_existing_artifact : Optional [bool ] = True ,
1415- remove_existing_artifact : Optional [bool ] = True ,
1416- timeout : Optional [int ] = None ,
1408+ self ,
1409+ target_dir : str ,
1410+ auth : Optional [Dict ] = None ,
1411+ force_overwrite : Optional [bool ] = False ,
1412+ bucket_uri : Optional [str ] = None ,
1413+ region : Optional [str ] = None ,
1414+ overwrite_existing_artifact : Optional [bool ] = True ,
1415+ remove_existing_artifact : Optional [bool ] = True ,
1416+ timeout : Optional [int ] = None ,
14171417 ):
14181418 """Downloads model artifacts from the model catalog.
14191419
@@ -1488,9 +1488,9 @@ def download_artifact(
14881488 )
14891489
14901490 if (
1491- artifact_size > _MAX_ARTIFACT_SIZE_IN_BYTES
1492- or bucket_uri
1493- or model_by_reference
1491+ artifact_size > _MAX_ARTIFACT_SIZE_IN_BYTES
1492+ or bucket_uri
1493+ or model_by_reference
14941494 ):
14951495 artifact_downloader = LargeArtifactDownloader (
14961496 dsc_model = self .dsc_model ,
@@ -1536,22 +1536,21 @@ def update(self, **kwargs) -> "DataScienceModel":
15361536 self .dsc_model = self ._to_oci_dsc_model (** kwargs ).update ()
15371537
15381538 logger .debug (f"Updating a model provenance metadata { self .provenance_metadata } " )
1539- if self .provenance_metadata :
1540- try :
1541- self .dsc_model .get_model_provenance ()
1542- self .dsc_model .update_model_provenance (
1543- self .provenance_metadata ._to_oci_metadata ()
1544- )
1545- except ModelProvenanceNotFoundError :
1546- self .dsc_model .create_model_provenance (
1547- self .provenance_metadata ._to_oci_metadata ()
1548- )
1539+ try :
1540+ self .dsc_model .get_model_provenance ()
1541+ self .dsc_model .update_model_provenance (
1542+ self .provenance_metadata ._to_oci_metadata ()
1543+ )
1544+ except ModelProvenanceNotFoundError :
1545+ self .dsc_model .create_model_provenance (
1546+ self .provenance_metadata ._to_oci_metadata ()
1547+ )
15491548
15501549 return self .sync ()
15511550
15521551 def delete (
1553- self ,
1554- delete_associated_model_deployment : Optional [bool ] = False ,
1552+ self ,
1553+ delete_associated_model_deployment : Optional [bool ] = False ,
15551554 ) -> "DataScienceModel" :
15561555 """Removes model from the model catalog.
15571556
@@ -1570,7 +1569,7 @@ def delete(
15701569
15711570 @classmethod
15721571 def list (
1573- cls , compartment_id : str = None , project_id : str = None , ** kwargs
1572+ cls , compartment_id : str = None , project_id : str = None , ** kwargs
15741573 ) -> List ["DataScienceModel" ]:
15751574 """Lists datascience models in a given compartment.
15761575
@@ -1597,7 +1596,7 @@ def list(
15971596
15981597 @classmethod
15991598 def list_df (
1600- cls , compartment_id : str = None , project_id : str = None , ** kwargs
1599+ cls , compartment_id : str = None , project_id : str = None , ** kwargs
16011600 ) -> "pandas.DataFrame" :
16021601 """Lists datascience models in a given compartment.
16031602
@@ -1617,7 +1616,7 @@ def list_df(
16171616 """
16181617 records = []
16191618 for model in OCIDataScienceModel .list_resource (
1620- compartment_id , project_id = project_id , ** kwargs
1619+ compartment_id , project_id = project_id , ** kwargs
16211620 ):
16221621 records .append (
16231622 {
@@ -1660,8 +1659,6 @@ def _init_complex_attributes(self):
16601659 self .with_provenance_metadata (self .provenance_metadata )
16611660 self .with_input_schema (self .input_schema )
16621661 self .with_output_schema (self .output_schema )
1663- # self.with_backup_setting(self.backup_setting)
1664- # self.with_retention_setting(self.retention_setting)
16651662
16661663 def _to_oci_dsc_model (self , ** kwargs ):
16671664 """Creates an `OCIDataScienceModel` instance from the `DataScienceModel`.
@@ -1700,7 +1697,7 @@ def _to_oci_dsc_model(self, **kwargs):
17001697 return OCIDataScienceModel (** dsc_spec )
17011698
17021699 def _update_from_oci_dsc_model (
1703- self , dsc_model : OCIDataScienceModel
1700+ self , dsc_model : OCIDataScienceModel
17041701 ) -> "DataScienceModel" :
17051702 """Update the properties from an OCIDataScienceModel object.
17061703
@@ -1973,12 +1970,12 @@ def _download_file_description_artifact(self) -> Tuple[Union[str, List[str]], in
19731970 return bucket_uri [0 ] if len (bucket_uri ) == 1 else bucket_uri , artifact_size
19741971
19751972 def add_artifact (
1976- self ,
1977- uri : Optional [str ] = None ,
1978- namespace : Optional [str ] = None ,
1979- bucket : Optional [str ] = None ,
1980- prefix : Optional [str ] = None ,
1981- files : Optional [List [str ]] = None ,
1973+ self ,
1974+ uri : Optional [str ] = None ,
1975+ namespace : Optional [str ] = None ,
1976+ bucket : Optional [str ] = None ,
1977+ prefix : Optional [str ] = None ,
1978+ files : Optional [List [str ]] = None ,
19821979 ):
19831980 """
19841981 Adds information about objects in a specified bucket to the model description JSON.
@@ -2127,11 +2124,11 @@ def list_obj_versions_unpaginated():
21272124 self .set_spec (self .CONST_MODEL_FILE_DESCRIPTION , tmp_model_file_description )
21282125
21292126 def remove_artifact (
2130- self ,
2131- uri : Optional [str ] = None ,
2132- namespace : Optional [str ] = None ,
2133- bucket : Optional [str ] = None ,
2134- prefix : Optional [str ] = None ,
2127+ self ,
2128+ uri : Optional [str ] = None ,
2129+ namespace : Optional [str ] = None ,
2130+ bucket : Optional [str ] = None ,
2131+ prefix : Optional [str ] = None ,
21352132 ):
21362133 """
21372134 Removes information about objects in a specified bucket or using a specified URI from the model description JSON.
0 commit comments