@@ -153,6 +153,13 @@ def __init__(self, state: str):
153153 super ().__init__ (msg )
154154
155155
156+ class ArtifactsNotAvailableError (Exception ):
157+ def __init__ (
158+ self , msg = "Model artifacts are either not generated or not available locally."
159+ ):
160+ super ().__init__ (msg )
161+
162+
156163class SerializeModelNotImplementedError (NotImplementedError ): # pragma: no cover
157164 pass
158165
@@ -281,6 +288,8 @@ class GenericModel(MetadataMixin, Introspectable, EvaluatorMixin):
281288 Tests if deployment works in local environment.
282289 upload_artifact(...)
283290 Uploads model artifacts to the provided `uri`.
291+ download_artifact(...)
292+ Downloads model artifacts from the model catalog.
284293
285294
286295 Examples
@@ -1249,6 +1258,9 @@ def verify(
12491258 Dict
12501259 A dictionary which contains prediction results.
12511260 """
1261+ if self .model_artifact is None :
1262+ raise ArtifactsNotAvailableError
1263+
12521264 endpoint = f"http://127.0.0.1:8000/predict"
12531265 data = self ._handle_input_data (data , auto_serialize_data , ** kwargs )
12541266
@@ -1402,6 +1414,117 @@ def from_model_artifact(
14021414
14031415 return model
14041416
1417+ def download_artifact (
1418+ self ,
1419+ artifact_dir : Optional [str ] = None ,
1420+ auth : Optional [Dict ] = None ,
1421+ force_overwrite : Optional [bool ] = False ,
1422+ bucket_uri : Optional [str ] = None ,
1423+ remove_existing_artifact : Optional [bool ] = True ,
1424+ ** kwargs ,
1425+ ) -> "GenericModel" :
1426+ """Downloads model artifacts from the model catalog.
1427+
1428+ Parameters
1429+ ----------
1430+ artifact_dir: (str, optional). Defaults to `None`.
1431+ The artifact directory to store the files needed for deployment.
1432+ Will be created if not exists.
1433+ auth: (Dict, optional). Defaults to None.
1434+ The default authentication is set using `ads.set_auth` API. If you need to override the
1435+ default, use the `ads.common.auth.api_keys` or `ads.common.auth.resource_principal` to create appropriate
1436+ authentication signer and kwargs required to instantiate IdentityClient object.
1437+ force_overwrite: (bool, optional). Defaults to False.
1438+ Whether to overwrite existing files or not.
1439+ bucket_uri: (str, optional). Defaults to None.
1440+ The OCI Object Storage URI where model artifacts will be copied to.
1441+ The `bucket_uri` is only necessary for downloading large artifacts with
1442+ size is greater than 2GB. Example: `oci://<bucket_name>@<namespace>/prefix/`.
1443+ remove_existing_artifact: (bool, optional). Defaults to `True`.
1444+ Whether artifacts uploaded to object storage bucket need to be removed or not.
1445+
1446+ Returns
1447+ -------
1448+ Self
1449+ An instance of `GenericModel` class.
1450+
1451+ Raises
1452+ ------
1453+ ValueError
1454+ If `model_id` is not available in the GenericModel object.
1455+ """
1456+ model_id = self .model_id
1457+ if not model_id :
1458+ raise ValueError (
1459+ "`model_id` is not available, load the GenericModel object first."
1460+ )
1461+
1462+ if not artifact_dir :
1463+ artifact_dir = self .artifact_dir
1464+ artifact_dir = _prepare_artifact_dir (artifact_dir )
1465+
1466+ target_dir = (
1467+ _prepare_artifact_dir ()
1468+ if ObjectStorageDetails .is_oci_path (artifact_dir )
1469+ else artifact_dir
1470+ )
1471+
1472+ dsc_model = DataScienceModel .from_id (model_id )
1473+ dsc_model .download_artifact (
1474+ target_dir = target_dir ,
1475+ force_overwrite = force_overwrite ,
1476+ bucket_uri = bucket_uri ,
1477+ remove_existing_artifact = remove_existing_artifact ,
1478+ auth = auth ,
1479+ region = kwargs .pop ("region" , None ),
1480+ timeout = kwargs .pop ("timeout" , None ),
1481+ )
1482+ model_artifact = ModelArtifact .from_uri (
1483+ uri = target_dir ,
1484+ artifact_dir = artifact_dir ,
1485+ model_file_name = self .model_file_name ,
1486+ force_overwrite = force_overwrite ,
1487+ auth = auth ,
1488+ ignore_conda_error = self .ignore_conda_error ,
1489+ )
1490+ self .dsc_model = dsc_model
1491+ self .local_copy_dir = model_artifact .local_copy_dir
1492+ self .model_artifact = model_artifact
1493+ self .reload_runtime_info ()
1494+
1495+ self ._summary_status .update_status (
1496+ detail = "Generated score.py" ,
1497+ status = ModelState .DONE .value ,
1498+ )
1499+ self ._summary_status .update_status (
1500+ detail = "Generated runtime.yaml" ,
1501+ status = ModelState .DONE .value ,
1502+ )
1503+ self ._summary_status .update_status (
1504+ detail = "Serialized model" , status = ModelState .DONE .value
1505+ )
1506+ self ._summary_status .update_status (
1507+ detail = "Populated metadata(Custom, Taxonomy and Provenance)" ,
1508+ status = ModelState .DONE .value ,
1509+ )
1510+ self ._summary_status .update_status (
1511+ detail = "Local tested .predict from score.py" ,
1512+ status = ModelState .AVAILABLE .value ,
1513+ )
1514+ self ._summary_status .update_action (
1515+ detail = "Local tested .predict from score.py" ,
1516+ action = "" ,
1517+ )
1518+ self ._summary_status .update_status (
1519+ detail = "Conducted Introspect Test" ,
1520+ status = ModelState .AVAILABLE .value ,
1521+ )
1522+ self ._summary_status .update_status (
1523+ detail = "Uploaded artifact to model catalog" ,
1524+ status = ModelState .AVAILABLE .value ,
1525+ )
1526+ return self
1527+
14051528 @classmethod
14061529 def from_model_catalog (
14071530 cls : Type [Self ],
@@ -1414,6 +1537,7 @@ def from_model_catalog(
14141537 bucket_uri : Optional [str ] = None ,
14151538 remove_existing_artifact : Optional [bool ] = True ,
14161539 ignore_conda_error : Optional [bool ] = False ,
1540+ download_artifact : Optional [bool ] = True ,
14171541 ** kwargs ,
14181542 ) -> Self :
14191543 """Loads model from model catalog.
@@ -1443,6 +1567,8 @@ def from_model_catalog(
14431567 Wether artifacts uploaded to object storage bucket need to be removed or not.
14441568 ignore_conda_error: (bool, optional). Defaults to False.
14451569 Parameter to ignore error when collecting conda information.
1570+ download_artifact: (bool, optional). Defaults to True.
1571+ Whether to download the model pickle or checkpoints
14461572 kwargs:
14471573 compartment_id : (str, optional)
14481574 Compartment OCID. If not specified, the value will be taken from the environment variables.
@@ -1475,14 +1601,60 @@ def from_model_catalog(
14751601 artifact_dir = _prepare_artifact_dir (artifact_dir )
14761602
14771603 target_dir = (
1478- artifact_dir
1479- if not ObjectStorageDetails .is_oci_path (artifact_dir )
1480- else tempfile . mkdtemp ()
1604+ _prepare_artifact_dir ()
1605+ if ObjectStorageDetails .is_oci_path (artifact_dir )
1606+ else artifact_dir
14811607 )
14821608 bucket_uri = bucket_uri or (
14831609 artifact_dir if ObjectStorageDetails .is_oci_path (artifact_dir ) else None
14841610 )
14851611 dsc_model = DataScienceModel .from_id (model_id )
1612+
1613+ if not download_artifact :
1614+ result_model = cls (
1615+ artifact_dir = artifact_dir ,
1616+ bucket_uri = bucket_uri ,
1617+ auth = auth ,
1618+ properties = properties ,
1619+ ignore_conda_error = ignore_conda_error ,
1620+ ** kwargs ,
1621+ )
1622+ result_model ._summary_status .update_status (
1623+ detail = "Generated score.py" ,
1624+ status = ModelState .NOTAPPLICABLE .value ,
1625+ )
1626+ result_model ._summary_status .update_status (
1627+ detail = "Generated runtime.yaml" ,
1628+ status = ModelState .NOTAPPLICABLE .value ,
1629+ )
1630+ result_model ._summary_status .update_status (
1631+ detail = "Serialized model" , status = ModelState .NOTAPPLICABLE .value
1632+ )
1633+ result_model ._summary_status .update_status (
1634+ detail = "Populated metadata(Custom, Taxonomy and Provenance)" ,
1635+ status = ModelState .NOTAPPLICABLE .value ,
1636+ )
1637+ result_model ._summary_status .update_status (
1638+ detail = "Local tested .predict from score.py" ,
1639+ status = ModelState .NOTAPPLICABLE .value ,
1640+ )
1641+ result_model ._summary_status .update_action (
1642+ detail = "Local tested .predict from score.py" ,
1643+ action = "Local artifact is not available. "
1644+ "Set load_artifact flag to True while loading the model or "
1645+ "call .download_artifact()." ,
1646+ )
1647+ result_model ._summary_status .update_status (
1648+ detail = "Conducted Introspect Test" ,
1649+ status = ModelState .NOTAPPLICABLE .value ,
1650+ )
1651+ result_model ._summary_status .update_status (
1652+ detail = "Uploaded artifact to model catalog" ,
1653+ status = ModelState .NOTAPPLICABLE .value ,
1654+ )
1655+ result_model .dsc_model = dsc_model
1656+ return result_model
1657+
14861658 dsc_model .download_artifact (
14871659 target_dir = target_dir ,
14881660 force_overwrite = force_overwrite ,
@@ -1536,6 +1708,7 @@ def from_model_deployment(
15361708 bucket_uri : Optional [str ] = None ,
15371709 remove_existing_artifact : Optional [bool ] = True ,
15381710 ignore_conda_error : Optional [bool ] = False ,
1711+ download_artifact : Optional [bool ] = True ,
15391712 ** kwargs ,
15401713 ) -> Self :
15411714 """Loads model from model deployment.
@@ -1565,6 +1738,8 @@ def from_model_deployment(
15651738 Wether artifacts uploaded to object storage bucket need to be removed or not.
15661739 ignore_conda_error: (bool, optional). Defaults to False.
15671740 Parameter to ignore error when collecting conda information.
1741+ download_artifact: (bool, optional). Defaults to True.
1742+ Whether to download the model pickle or checkpoints
15681743 kwargs:
15691744 compartment_id : (str, optional)
15701745 Compartment OCID. If not specified, the value will be taken from the environment variables.
@@ -1608,6 +1783,7 @@ def from_model_deployment(
16081783 bucket_uri = bucket_uri ,
16091784 remove_existing_artifact = remove_existing_artifact ,
16101785 ignore_conda_error = ignore_conda_error ,
1786+ download_artifact = download_artifact ,
16111787 ** kwargs ,
16121788 )
16131789 model ._summary_status .update_status (
@@ -1730,6 +1906,7 @@ def from_id(
17301906 bucket_uri : Optional [str ] = None ,
17311907 remove_existing_artifact : Optional [bool ] = True ,
17321908 ignore_conda_error : Optional [bool ] = False ,
1909+ download_artifact : Optional [bool ] = True ,
17331910 ** kwargs ,
17341911 ) -> Self :
17351912 """Loads model from model OCID or model deployment OCID.
@@ -1757,6 +1934,10 @@ def from_id(
17571934 size is greater than 2GB. Example: `oci://<bucket_name>@<namespace>/prefix/`.
17581935 remove_existing_artifact: (bool, optional). Defaults to `True`.
17591936 Wether artifacts uploaded to object storage bucket need to be removed or not.
1937+ ignore_conda_error: (bool, optional). Defaults to False.
1938+ Parameter to ignore error when collecting conda information.
1939+ download_artifact: (bool, optional). Defaults to True.
1940+ Whether to download the model pickle or checkpoints
17601941 kwargs:
17611942 compartment_id : (str, optional)
17621943 Compartment OCID. If not specified, the value will be taken from the environment variables.
@@ -1780,6 +1961,7 @@ def from_id(
17801961 bucket_uri = bucket_uri ,
17811962 remove_existing_artifact = remove_existing_artifact ,
17821963 ignore_conda_error = ignore_conda_error ,
1964+ download_artifact = download_artifact ,
17831965 ** kwargs ,
17841966 )
17851967 elif DataScienceModelType .MODEL in ocid :
@@ -1793,6 +1975,7 @@ def from_id(
17931975 bucket_uri = bucket_uri ,
17941976 remove_existing_artifact = remove_existing_artifact ,
17951977 ignore_conda_error = ignore_conda_error ,
1978+ download_artifact = download_artifact ,
17961979 ** kwargs ,
17971980 )
17981981 else :
@@ -1924,11 +2107,13 @@ def save(
19242107 ... bucket_uri="oci://my-bucket@my-tenancy/",
19252108 ... overwrite_existing_artifact=True,
19262109 ... remove_existing_artifact=True,
1927- ... remove_existing_artifact=True,
19282110 ... parallel_process_count=9,
19292111 ... )
19302112
19312113 """
2114+ if self .model_artifact is None :
2115+ raise ArtifactsNotAvailableError
2116+
19322117 # Set default display_name if not specified - randomly generated easy to remember name generated
19332118 if not display_name :
19342119 display_name = self ._random_display_name ()
0 commit comments