Skip to content

Commit 4eb2deb

Browse files
committed
Merge branch 'main' of https://github.com/oracle/accelerated-data-science into ODSC-47079/fixed_pipeline_step_to_dict
2 parents 1367e4c + 7d39f33 commit 4eb2deb

File tree

6 files changed

+36
-29
lines changed

6 files changed

+36
-29
lines changed

ads/model/deployment/model_deployment.py

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1304,7 +1304,8 @@ def from_id(cls, id: str) -> "ModelDeployment":
13041304
ModelDeployment
13051305
The ModelDeployment instance (self).
13061306
"""
1307-
return cls()._update_from_oci_model(OCIDataScienceModelDeployment.from_id(id))
1307+
oci_model = OCIDataScienceModelDeployment.from_id(id)
1308+
return cls(properties=oci_model)._update_from_oci_model(oci_model)
13081309

13091310
@classmethod
13101311
def from_dict(cls, obj_dict: Dict) -> "ModelDeployment":
@@ -1503,7 +1504,9 @@ def _build_model_deployment_details(self) -> CreateModelDeploymentDetails:
15031504
**create_model_deployment_details
15041505
).to_oci_model(CreateModelDeploymentDetails)
15051506

1506-
def _update_model_deployment_details(self, **kwargs) -> UpdateModelDeploymentDetails:
1507+
def _update_model_deployment_details(
1508+
self, **kwargs
1509+
) -> UpdateModelDeploymentDetails:
15071510
"""Builds UpdateModelDeploymentDetails from model deployment instance.
15081511
15091512
Returns
@@ -1527,7 +1530,7 @@ def _update_model_deployment_details(self, **kwargs) -> UpdateModelDeploymentDet
15271530
return OCIDataScienceModelDeployment(
15281531
**update_model_deployment_details
15291532
).to_oci_model(UpdateModelDeploymentDetails)
1530-
1533+
15311534
def _update_spec(self, **kwargs) -> "ModelDeployment":
15321535
"""Updates model deployment specs from kwargs.
15331536
@@ -1542,7 +1545,7 @@ def _update_spec(self, **kwargs) -> "ModelDeployment":
15421545
Model deployment freeform tags
15431546
defined_tags: (dict)
15441547
Model deployment defined tags
1545-
1548+
15461549
Additional kwargs arguments.
15471550
Can be any attribute that `ads.model.deployment.ModelDeploymentCondaRuntime`, `ads.model.deployment.ModelDeploymentContainerRuntime`
15481551
and `ads.model.deployment.ModelDeploymentInfrastructure` accepts.
@@ -1559,20 +1562,22 @@ def _update_spec(self, **kwargs) -> "ModelDeployment":
15591562
specs = {
15601563
"self": self._spec,
15611564
"runtime": self.runtime._spec,
1562-
"infrastructure": self.infrastructure._spec
1565+
"infrastructure": self.infrastructure._spec,
15631566
}
15641567
sub_set = {
15651568
self.infrastructure.CONST_ACCESS_LOG,
15661569
self.infrastructure.CONST_PREDICT_LOG,
1567-
self.infrastructure.CONST_SHAPE_CONFIG_DETAILS
1570+
self.infrastructure.CONST_SHAPE_CONFIG_DETAILS,
15681571
}
15691572
for spec_value in specs.values():
15701573
for key in spec_value:
15711574
if key in converted_specs:
15721575
if key in sub_set:
15731576
for sub_key in converted_specs[key]:
15741577
converted_sub_key = ads_utils.snake_to_camel(sub_key)
1575-
spec_value[key][converted_sub_key] = converted_specs[key][sub_key]
1578+
spec_value[key][converted_sub_key] = converted_specs[key][
1579+
sub_key
1580+
]
15761581
else:
15771582
spec_value[key] = copy.deepcopy(converted_specs[key])
15781583
self = (
@@ -1616,14 +1621,14 @@ def _build_model_deployment_configuration_details(self) -> Dict:
16161621
infrastructure.CONST_MEMORY_IN_GBS: infrastructure.shape_config_details.get(
16171622
"memory_in_gbs", None
16181623
)
1619-
or infrastructure.shape_config_details.get(
1620-
"memoryInGBs", None
1621-
)
1624+
or infrastructure.shape_config_details.get("memoryInGBs", None)
16221625
or DEFAULT_MEMORY_IN_GBS,
16231626
}
16241627

16251628
if infrastructure.subnet_id:
1626-
instance_configuration[infrastructure.CONST_SUBNET_ID] = infrastructure.subnet_id
1629+
instance_configuration[
1630+
infrastructure.CONST_SUBNET_ID
1631+
] = infrastructure.subnet_id
16271632

16281633
scaling_policy = {
16291634
infrastructure.CONST_POLICY_TYPE: "FIXED_SIZE",
@@ -1638,13 +1643,11 @@ def _build_model_deployment_configuration_details(self) -> Dict:
16381643

16391644
model_id = runtime.model_uri
16401645
if not model_id.startswith("ocid"):
1641-
16421646
from ads.model.datascience_model import DataScienceModel
1643-
1647+
16441648
dsc_model = DataScienceModel(
16451649
name=self.display_name,
1646-
compartment_id=self.infrastructure.compartment_id
1647-
or COMPARTMENT_OCID,
1650+
compartment_id=self.infrastructure.compartment_id or COMPARTMENT_OCID,
16481651
project_id=self.infrastructure.project_id or PROJECT_OCID,
16491652
artifact=runtime.model_uri,
16501653
).create(
@@ -1653,7 +1656,7 @@ def _build_model_deployment_configuration_details(self) -> Dict:
16531656
region=runtime.region,
16541657
overwrite_existing_artifact=runtime.overwrite_existing_artifact,
16551658
remove_existing_artifact=runtime.remove_existing_artifact,
1656-
timeout=runtime.timeout
1659+
timeout=runtime.timeout,
16571660
)
16581661
model_id = dsc_model.id
16591662

ads/model/service/oci_datascience_model.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,19 +38,19 @@
3838
)
3939

4040

41-
class ModelProvenanceNotFoundError(Exception): # pragma: no cover
41+
class ModelProvenanceNotFoundError(Exception): # pragma: no cover
4242
pass
4343

4444

45-
class ModelArtifactNotFoundError(Exception): # pragma: no cover
45+
class ModelArtifactNotFoundError(Exception): # pragma: no cover
4646
pass
4747

4848

49-
class ModelNotSavedError(Exception): # pragma: no cover
49+
class ModelNotSavedError(Exception): # pragma: no cover
5050
pass
5151

5252

53-
class ModelWithActiveDeploymentError(Exception): # pragma: no cover
53+
class ModelWithActiveDeploymentError(Exception): # pragma: no cover
5454
pass
5555

5656

@@ -410,7 +410,7 @@ def export_model_artifact(self, bucket_uri: str, region: str = None):
410410
# Show progress of exporting model artifacts
411411
self._wait_for_work_request(
412412
work_request_id=work_request_id,
413-
num_steps=3,
413+
num_steps=2,
414414
)
415415

416416
@check_for_model_id(
@@ -596,3 +596,7 @@ def _wait_for_work_request(self, work_request_id: str, num_steps: int = 3) -> No
596596
)
597597
else:
598598
break
599+
600+
while i < num_steps:
601+
progress.update()
602+
i += 1

docs/source/user_guide/model_catalog/model_catalog.rst

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1204,7 +1204,7 @@ If you don't have an Object Storage bucket, create one using the OCI SDK or the
12041204
12051205
Allow service datascience to manage object-family in compartment <compartment> where ALL {target.bucket.name='<bucket_name>'}
12061206
1207-
Allow service objectstorage to manage object-family in compartment <compartment> where ALL {target.bucket.name='<bucket_name>'}
1207+
Allow service objectstorage-<region> to manage object-family in compartment <compartment> where ALL {target.bucket.name='<bucket_name>'}
12081208
12091209
Saving
12101210
======
@@ -1545,4 +1545,3 @@ In the next example, the model that was stored in the model catalog as part of t
15451545
.. code-block:: python3
15461546
15471547
mc.delete_model(mc_model.id)
1548-

docs/source/user_guide/model_registration/large_model_artifact.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ If you don't have an Object Storage bucket, create one using the OCI SDK or the
1313
1414
Allow service datascience to manage object-family in compartment <compartment> where ALL {target.bucket.name='<bucket_name>'}
1515
16-
Allow service objectstorage to manage object-family in compartment <compartment> where ALL {target.bucket.name='<bucket_name>'}
16+
Allow service objectstorage-<region> to manage object-family in compartment <compartment> where ALL {target.bucket.name='<bucket_name>'}
1717
1818
See `API documentation <../../ads.model.html#id10>`__ for more details.
1919

docs/source/user_guide/model_registration/model_load.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ If you don't have an Object Storage bucket, create one using the OCI SDK or the
119119
120120
Allow service datascience to manage object-family in compartment <compartment> where ALL {target.bucket.name='<bucket_name>'}
121121
122-
Allow service objectstorage to manage object-family in compartment <compartment> where ALL {target.bucket.name='<bucket_name>'}
122+
Allow service objectstorage-<region> to manage object-family in compartment <compartment> where ALL {target.bucket.name='<bucket_name>'}
123123
124124
The following example loads a model using the large model artifact approach. The ``bucket_uri`` has the following syntax: ``oci://<bucket_name>@<namespace>/<path>/`` See `API documentation <../../ads.model.html#id4>`__ for more details.
125125

@@ -169,4 +169,4 @@ Alternatively the ``.from_id()`` method can be used to load registered or deploy
169169
bucket_uri=<oci://<bucket_name>@<namespace>/prefix/>,
170170
force_overwrite=True,
171171
remove_existing_artifact=True,
172-
)
172+
)

tests/unitary/default_setup/model/test_oci_datascience_model.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,6 @@
8888

8989
class TestOCIDataScienceModel:
9090
def setup_class(cls):
91-
9291
# Mock delete model response
9392
cls.mock_delete_model_response = Response(
9493
data=None, status=None, headers=None, request=None
@@ -229,7 +228,9 @@ def test_delete_success(self, mock_client):
229228
mock_model_deployment.return_value = [
230229
MagicMock(lifecycle_state="ACTIVE", identifier="md_id")
231230
]
232-
with patch("ads.model.deployment.ModelDeployment.from_id") as mock_from_id:
231+
with patch(
232+
"ads.model.deployment.ModelDeployment.from_id"
233+
) as mock_from_id:
233234
with patch.object(OCIDataScienceModel, "sync") as mock_sync:
234235
self.mock_model.delete(delete_associated_model_deployment=True)
235236
mock_from_id.assert_called_with("md_id")
@@ -445,7 +446,7 @@ def test_export_model_artifact(
445446
)
446447
mock_wait_for_work_request.assert_called_with(
447448
work_request_id="work_request_id",
448-
num_steps=3,
449+
num_steps=2,
449450
)
450451

451452
@patch.object(TqdmProgressBar, "update")

0 commit comments

Comments
 (0)