Skip to content

Commit 7e04aaf

Browse files
committed
Fixed model deployment failed return value.
1 parent fbc3677 commit 7e04aaf

File tree

4 files changed

+104
-117
lines changed

4 files changed

+104
-117
lines changed

ads/model/deployment/model_deployment.py

Lines changed: 57 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ class LogNotConfiguredError(Exception): # pragma: no cover
8585
pass
8686

8787

88-
class ModelDeploymentFailedError(Exception): # pragma: no cover
88+
class ModelDeploymentPredictError(Exception): # pragma: no cover
8989
pass
9090

9191

@@ -607,31 +607,24 @@ def deploy(
607607
-------
608608
ModelDeployment
609609
The instance of ModelDeployment.
610-
611-
Raises
612-
------
613-
ModelDeploymentFailedError
614-
If model deployment fails to deploy
615610
"""
616611
create_model_deployment_details = (
617612
self._build_model_deployment_details()
618613
if self._spec
619614
else self.properties.build()
620615
)
621616

622-
response = self.dsc_model_deployment.create(
623-
create_model_deployment_details=create_model_deployment_details,
624-
wait_for_completion=wait_for_completion,
625-
max_wait_time=max_wait_time,
626-
poll_interval=poll_interval,
627-
)
628-
629-
if response.lifecycle_state == State.FAILED.name:
630-
raise ModelDeploymentFailedError(
631-
f"Model deployment {response.id} failed to deploy: {response.lifecycle_details}"
617+
try:
618+
response = self.dsc_model_deployment.create(
619+
create_model_deployment_details=create_model_deployment_details,
620+
wait_for_completion=wait_for_completion,
621+
max_wait_time=max_wait_time,
622+
poll_interval=poll_interval,
632623
)
633-
634-
return self._update_from_oci_model(response)
624+
except:
625+
raise
626+
finally:
627+
return self._update_from_oci_model(response)
635628

636629
def delete(
637630
self,
@@ -657,12 +650,16 @@ def delete(
657650
ModelDeployment
658651
The instance of ModelDeployment.
659652
"""
660-
response = self.dsc_model_deployment.delete(
661-
wait_for_completion=wait_for_completion,
662-
max_wait_time=max_wait_time,
663-
poll_interval=poll_interval,
664-
)
665-
return self._update_from_oci_model(response)
653+
try:
654+
response = self.dsc_model_deployment.delete(
655+
wait_for_completion=wait_for_completion,
656+
max_wait_time=max_wait_time,
657+
poll_interval=poll_interval,
658+
)
659+
except:
660+
raise
661+
finally:
662+
return self._update_from_oci_model(response)
666663

667664
def update(
668665
self,
@@ -718,14 +715,17 @@ def update(
718715
else self._update_model_deployment_details(**kwargs)
719716
)
720717

721-
response = self.dsc_model_deployment.update(
722-
update_model_deployment_details=update_model_deployment_details,
723-
wait_for_completion=wait_for_completion,
724-
max_wait_time=max_wait_time,
725-
poll_interval=poll_interval,
726-
)
727-
728-
return self._update_from_oci_model(response)
718+
try:
719+
response = self.dsc_model_deployment.update(
720+
update_model_deployment_details=update_model_deployment_details,
721+
wait_for_completion=wait_for_completion,
722+
max_wait_time=max_wait_time,
723+
poll_interval=poll_interval,
724+
)
725+
except:
726+
raise
727+
finally:
728+
return self._update_from_oci_model(response)
729729

730730
def watch(
731731
self,
@@ -890,6 +890,10 @@ def predict(
890890
Prediction results.
891891
892892
"""
893+
if self.sync().lifecycle_state != "ACTIVE":
894+
raise ModelDeploymentPredictError(
895+
"Model Deployment must be in `ACTIVE` state before predict can be called."
896+
)
893897
endpoint = f"{self.url}/predict"
894898
signer = authutil.default_signer()["signer"]
895899
header = {
@@ -953,7 +957,7 @@ def predict(
953957
except oci.exceptions.ServiceError as ex:
954958
# When bandwidth exceeds the allocated value, TooManyRequests error (429) will be raised by oci backend.
955959
if ex.status == 429:
956-
bandwidth_mbps = self.infrastructure.bandwidth_mbps or MODEL_DEPLOYMENT_BANDWIDTH_MBPS
960+
bandwidth_mbps = self.infrastructure.bandwidth_mbps or DEFAULT_BANDWIDTH_MBPS
957961
utils.get_logger().warning(
958962
f"Load balancer bandwidth exceeds the allocated {bandwidth_mbps} Mbps."
959963
"To estimate the actual bandwidth, use formula: (payload size in KB) * (estimated requests per second) * 8 / 1024."
@@ -985,13 +989,16 @@ def activate(
985989
ModelDeployment
986990
The instance of ModelDeployment.
987991
"""
988-
response = self.dsc_model_deployment.activate(
989-
wait_for_completion=wait_for_completion,
990-
max_wait_time=max_wait_time,
991-
poll_interval=poll_interval,
992-
)
993-
994-
return self._update_from_oci_model(response)
992+
try:
993+
response = self.dsc_model_deployment.activate(
994+
wait_for_completion=wait_for_completion,
995+
max_wait_time=max_wait_time,
996+
poll_interval=poll_interval,
997+
)
998+
except:
999+
raise
1000+
finally:
1001+
return self._update_from_oci_model(response)
9951002

9961003
def deactivate(
9971004
self,
@@ -1017,13 +1024,16 @@ def deactivate(
10171024
ModelDeployment
10181025
The instance of ModelDeployment.
10191026
"""
1020-
response = self.dsc_model_deployment.deactivate(
1021-
wait_for_completion=wait_for_completion,
1022-
max_wait_time=max_wait_time,
1023-
poll_interval=poll_interval,
1024-
)
1025-
1026-
return self._update_from_oci_model(response)
1027+
try:
1028+
response = self.dsc_model_deployment.deactivate(
1029+
wait_for_completion=wait_for_completion,
1030+
max_wait_time=max_wait_time,
1031+
poll_interval=poll_interval,
1032+
)
1033+
except:
1034+
raise
1035+
finally:
1036+
return self._update_from_oci_model(response)
10271037

10281038
def _log_details(self, log_type: str = ModelDeploymentLogType.ACCESS):
10291039
"""Gets log details for the provided `log_type`.

ads/model/generic_model.py

Lines changed: 37 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -2315,16 +2315,20 @@ def deploy(
23152315
.with_runtime(runtime)
23162316
)
23172317

2318-
self.model_deployment = model_deployment.deploy(
2319-
wait_for_completion=wait_for_completion,
2320-
max_wait_time=max_wait_time,
2321-
poll_interval=poll_interval,
2322-
)
2323-
self._summary_status.update_status(
2324-
detail="Deployed the model",
2325-
status=self.model_deployment.state.name.upper(),
2326-
)
2327-
return self.model_deployment
2318+
try:
2319+
self.model_deployment = model_deployment.deploy(
2320+
wait_for_completion=wait_for_completion,
2321+
max_wait_time=max_wait_time,
2322+
poll_interval=poll_interval,
2323+
)
2324+
self._summary_status.update_status(
2325+
detail="Deployed the model",
2326+
status=self.model_deployment.state.name.upper(),
2327+
)
2328+
except:
2329+
raise
2330+
finally:
2331+
return self.model_deployment
23282332

23292333
def prepare_save_deploy(
23302334
self,
@@ -2792,25 +2796,29 @@ def restart_deployment(
27922796
"""
27932797
if not self.model_deployment:
27942798
raise ValueError("Use `deploy()` method to start model deployment.")
2795-
logger.info(
2796-
f"Deactivating model deployment {self.model_deployment.model_deployment_id}."
2797-
)
2798-
self.model_deployment.deactivate(
2799-
max_wait_time=max_wait_time, poll_interval=poll_interval
2800-
)
2801-
logger.info(
2802-
f"Model deployment {self.model_deployment.model_deployment_id} has successfully been deactivated."
2803-
)
2804-
logger.info(
2805-
f"Activating model deployment {self.model_deployment.model_deployment_id}."
2806-
)
2807-
self.model_deployment.activate(
2808-
max_wait_time=max_wait_time, poll_interval=poll_interval
2809-
)
2810-
logger.info(
2811-
f"Model deployment {self.model_deployment.model_deployment_id} has successfully been activated."
2812-
)
2813-
return self.model_deployment
2799+
try:
2800+
logger.info(
2801+
f"Deactivating model deployment {self.model_deployment.model_deployment_id}."
2802+
)
2803+
self.model_deployment.deactivate(
2804+
max_wait_time=max_wait_time, poll_interval=poll_interval
2805+
)
2806+
logger.info(
2807+
f"Model deployment {self.model_deployment.model_deployment_id} has successfully been deactivated."
2808+
)
2809+
logger.info(
2810+
f"Activating model deployment {self.model_deployment.model_deployment_id}."
2811+
)
2812+
self.model_deployment.activate(
2813+
max_wait_time=max_wait_time, poll_interval=poll_interval
2814+
)
2815+
logger.info(
2816+
f"Model deployment {self.model_deployment.model_deployment_id} has successfully been activated."
2817+
)
2818+
except:
2819+
raise
2820+
finally:
2821+
return self.model_deployment
28142822

28152823
@class_or_instance_method
28162824
def delete(

ads/model/service/oci_datascience_model_deployment.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,8 @@ def activate(
212212
f"Error while trying to activate model deployment: {self.id}"
213213
)
214214
raise e
215+
finally:
216+
return self.sync()
215217

216218
return self.sync()
217219
else:
@@ -264,6 +266,8 @@ def create(
264266
f"Error while trying to create model deployment: {self.id}"
265267
)
266268
raise e
269+
finally:
270+
return self.sync()
267271

268272
return self.sync()
269273

@@ -328,6 +332,8 @@ def deactivate(
328332
f"Error while trying to deactivate model deployment: {self.id}"
329333
)
330334
raise e
335+
finally:
336+
return self.sync()
331337

332338
return self.sync()
333339
else:
@@ -399,6 +405,8 @@ def delete(
399405
f"Error while trying to delete model deployment: {self.id}"
400406
)
401407
raise e
408+
finally:
409+
return self.sync()
402410

403411
return self.sync()
404412

@@ -454,8 +462,8 @@ def update(
454462
except Exception as e:
455463
logger.error(f"Error while trying to update model deployment: {self.id}")
456464
raise e
457-
458-
return self.sync()
465+
finally:
466+
return self.sync()
459467

460468
@classmethod
461469
def list(

tests/unitary/default_setup/model_deployment/test_model_deployment_v2.py

Lines changed: 0 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
from ads.model.deployment.model_deployment import (
2222
ModelDeployment,
2323
ModelDeploymentLogType,
24-
ModelDeploymentFailedError,
2524
)
2625
from ads.model.deployment.model_deployment_infrastructure import (
2726
ModelDeploymentInfrastructure,
@@ -1148,44 +1147,6 @@ def test_deploy(
11481147
mock_create_model_deployment.assert_called_with(create_model_deployment_details)
11491148
mock_sync.assert_called()
11501149

1151-
@patch.object(OCIDataScienceMixin, "sync")
1152-
@patch.object(
1153-
oci.data_science.DataScienceClient,
1154-
"create_model_deployment",
1155-
)
1156-
@patch.object(DataScienceModel, "create")
1157-
def test_deploy_failed(
1158-
self, mock_create, mock_create_model_deployment, mock_sync
1159-
):
1160-
dsc_model = MagicMock()
1161-
dsc_model.id = "fakeid.datasciencemodel.oc1.iad.xxx"
1162-
mock_create.return_value = dsc_model
1163-
response = oci.response.Response(
1164-
status=MagicMock(),
1165-
headers=MagicMock(),
1166-
request=MagicMock(),
1167-
data=oci.data_science.models.ModelDeployment(
1168-
id="test_model_deployment_id",
1169-
lifecycle_state="FAILED",
1170-
lifecycle_details="The specified log object is not found or user is not authorized.",
1171-
),
1172-
)
1173-
mock_sync.return_value = response.data
1174-
model_deployment = self.initialize_model_deployment()
1175-
create_model_deployment_details = (
1176-
model_deployment._build_model_deployment_details()
1177-
)
1178-
with pytest.raises(
1179-
ModelDeploymentFailedError,
1180-
match=f"Model deployment {response.data.id} failed to deploy: {response.data.lifecycle_details}",
1181-
):
1182-
model_deployment.deploy(wait_for_completion=False)
1183-
mock_create.assert_called()
1184-
mock_create_model_deployment.assert_called_with(
1185-
create_model_deployment_details
1186-
)
1187-
mock_sync.assert_called()
1188-
11891150
@patch.object(
11901151
OCIDataScienceModelDeployment,
11911152
"activate",

0 commit comments

Comments
 (0)