Skip to content

Commit 5ddef7e

Browse files
authored
ODSC-39737 : allow to use GenericModel.predict() locally (#127)
2 parents 78fb1e7 + d42a2ad commit 5ddef7e

File tree

2 files changed

+70
-27
lines changed

2 files changed

+70
-27
lines changed

ads/model/generic_model.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2497,6 +2497,7 @@ def predict(
24972497
self,
24982498
data: Any = None,
24992499
auto_serialize_data: bool = False,
2500+
local: bool = False,
25002501
**kwargs,
25012502
) -> Dict[str, Any]:
25022503
"""Returns prediction of input data run against the model deployment endpoint.
@@ -2521,6 +2522,8 @@ def predict(
25212522
Whether to auto serialize input data. Defaults to `False` for GenericModel, and `True` for other frameworks.
25222523
`data` required to be json serializable if `auto_serialize_data=False`.
25232524
If `auto_serialize_data` set to True, data will be serialized before sending to model deployment endpoint.
2525+
local: bool.
2526+
Whether to invoke the prediction locally. Defaults to False.
25242527
kwargs:
25252528
content_type: str, used to indicate the media type of the resource.
25262529
image: PIL.Image Object or uri for the image.
@@ -2539,10 +2542,21 @@ def predict(
25392542
NotActiveDeploymentError
25402543
If model deployment process was not started or not finished yet.
25412544
ValueError
2542-
If `data` is empty or not JSON serializable.
2545+
If model is not deployed yet or the endpoint information is not available.
25432546
"""
2544-
if not self.model_deployment:
2545-
raise ValueError("Use `deploy()` method to start model deployment.")
2547+
if local:
2548+
return self.verify(
2549+
data=data, auto_serialize_data=auto_serialize_data, **kwargs
2550+
)
2551+
2552+
if not (self.model_deployment and self.model_deployment.url):
2553+
raise ValueError(
2554+
"Error invoking the remote endpoint as the model is not "
2555+
"deployed yet or the endpoint information is not available. "
2556+
"Use `deploy()` method to start model deployment. "
2557+
"If you intend to invoke inference using locally available "
2558+
"model artifact, set parameter `local=True`"
2559+
)
25462560

25472561
current_state = self.model_deployment.state.name.upper()
25482562
if current_state != ModelDeploymentState.ACTIVE.name:

tests/unitary/with_extras/model/test_generic_model.py

Lines changed: 53 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,20 @@
170170

171171
INFERENCE_CONDA_ENV = "oci://bucket@namespace/<path_to_service_pack>"
172172
TRAINING_CONDA_ENV = "oci://bucket@namespace/<path_to_service_pack>"
173+
DEFAULT_PYTHON_VERSION = "3.8"
174+
MODEL_FILE_NAME = "fake_model_name"
175+
FAKE_MD_URL = "http://<model-deployment-url>"
176+
177+
178+
def _prepare(model):
179+
model.prepare(
180+
inference_conda_env=INFERENCE_CONDA_ENV,
181+
inference_python_version=DEFAULT_PYTHON_VERSION,
182+
training_conda_env=TRAINING_CONDA_ENV,
183+
training_python_version=DEFAULT_PYTHON_VERSION,
184+
model_file_name=MODEL_FILE_NAME,
185+
force_overwrite=True,
186+
)
173187

174188

175189
class TestEstimator:
@@ -315,14 +329,7 @@ def test_prepare_with_custom_scorepy(self, mock_signer):
315329
@patch("ads.common.auth.default_signer")
316330
def test_verify_without_reload(self, mock_signer):
317331
"""Test verify input data without reload artifacts."""
318-
self.generic_model.prepare(
319-
inference_conda_env="oci://service-conda-packs@ociodscdev/service_pack/cpu/General_Machine_Learning_for_CPUs/1.0/mlcpuv1",
320-
inference_python_version="3.6",
321-
training_conda_env="oci://service-conda-packs@ociodscdev/service_pack/cpu/Oracle_Database_for_CPU_Python_3.7/1.0/database_p37_cpu_v1",
322-
training_python_version="3.7",
323-
model_file_name="fake_model_name",
324-
force_overwrite=True,
325-
)
332+
_prepare(self.generic_model)
326333
self.generic_model.verify(self.X_test.tolist())
327334

328335
with patch("ads.model.artifact.ModelArtifact.reload") as mock_reload:
@@ -332,20 +339,10 @@ def test_verify_without_reload(self, mock_signer):
332339
@patch("ads.common.auth.default_signer")
333340
def test_verify(self, mock_signer):
334341
"""Test verify input data"""
335-
self.generic_model.prepare(
336-
inference_conda_env="oci://service-conda-packs@ociodscdev/service_pack/cpu/General_Machine_Learning_for_CPUs/1.0/mlcpuv1",
337-
inference_python_version="3.6",
338-
training_conda_env="oci://service-conda-packs@ociodscdev/service_pack/cpu/Oracle_Database_for_CPU_Python_3.7/1.0/database_p37_cpu_v1",
339-
training_python_version="3.7",
340-
model_file_name="fake_model_name",
341-
force_overwrite=True,
342-
)
342+
_prepare(self.generic_model)
343343
prediction_1 = self.generic_model.verify(self.X_test.tolist())
344344
assert isinstance(prediction_1, dict), "Failed to verify json payload."
345345

346-
prediction_2 = self.generic_model.verify(self.X_test.tolist())
347-
assert isinstance(prediction_2, dict), "Failed to verify input data."
348-
349346
def test_reload(self):
350347
"""test the reload."""
351348
pass
@@ -637,11 +634,31 @@ def test_deploy_with_default_display_name(self, mock_deploy):
637634
== random_name[:-9]
638635
)
639636

637+
@pytest.mark.parametrize("input_data", [(X_test.tolist())])
638+
@patch("ads.common.auth.default_signer")
639+
def test_predict_locally(self, mock_signer, input_data):
640+
_prepare(self.generic_model)
641+
test_result = self.generic_model.predict(data=input_data, local=True)
642+
expected_result = self.generic_model.estimator.predict(input_data).tolist()
643+
assert (
644+
test_result["prediction"] == expected_result
645+
), "Failed to verify input data."
646+
647+
with patch("ads.model.artifact.ModelArtifact.reload") as mock_reload:
648+
self.generic_model.predict(
649+
data=input_data, local=True, reload_artifacts=False
650+
)
651+
mock_reload.assert_not_called()
652+
640653
@patch.object(ModelDeployment, "predict")
641654
@patch("ads.common.auth.default_signer")
642655
@patch("ads.common.oci_client.OCIClientFactory")
656+
@patch(
657+
"ads.model.deployment.model_deployment.ModelDeployment.url",
658+
return_value=FAKE_MD_URL,
659+
)
643660
def test_predict_with_not_active_deployment_fail(
644-
self, mock_client, mock_signer, mock_predict
661+
self, mock_url, mock_client, mock_signer, mock_predict
645662
):
646663
"""Ensures predict fails when the model deployment is not in an active state."""
647664
with pytest.raises(NotActiveDeploymentError):
@@ -661,7 +678,11 @@ def test_predict_with_not_active_deployment_fail(
661678

662679
@patch("ads.common.auth.default_signer")
663680
@patch("ads.common.oci_client.OCIClientFactory")
664-
def test_predict_bytes_success(self, mock_client, mock_signer):
681+
@patch(
682+
"ads.model.deployment.model_deployment.ModelDeployment.url",
683+
return_value=FAKE_MD_URL,
684+
)
685+
def test_predict_bytes_success(self, mock_url, mock_client, mock_signer):
665686
"""Ensures predict model passes with bytes input."""
666687
with patch.object(
667688
ModelDeployment, "state", new_callable=PropertyMock
@@ -670,7 +691,7 @@ def test_predict_bytes_success(self, mock_client, mock_signer):
670691
with patch.object(ModelDeployment, "predict") as mock_predict:
671692
mock_predict.return_value = {"result": "result"}
672693
self.generic_model.model_deployment = ModelDeployment(
673-
model_deployment_id="test"
694+
model_deployment_id="test",
674695
)
675696
# self.generic_model.model_deployment.current_state = ModelDeploymentState.ACTIVE
676697
self.generic_model._as_onnx = False
@@ -683,7 +704,11 @@ def test_predict_bytes_success(self, mock_client, mock_signer):
683704

684705
@patch("ads.common.auth.default_signer")
685706
@patch("ads.common.oci_client.OCIClientFactory")
686-
def test_predict_success(self, mock_client, mock_signer):
707+
@patch(
708+
"ads.model.deployment.model_deployment.ModelDeployment.url",
709+
return_value=FAKE_MD_URL,
710+
)
711+
def test_predict_success(self, mock_url, mock_client, mock_signer):
687712
"""Ensures predict model passes with valid input parameters."""
688713
with patch.object(
689714
ModelDeployment, "state", new_callable=PropertyMock
@@ -800,7 +825,11 @@ def test_from_model_artifact(
800825

801826
@patch("ads.common.auth.default_signer")
802827
@patch("ads.common.oci_client.OCIClientFactory")
803-
def test_predict_success__serialize_input(self, mock_client, mock_signer):
828+
@patch(
829+
"ads.model.deployment.model_deployment.ModelDeployment.url",
830+
return_value=FAKE_MD_URL,
831+
)
832+
def test_predict_success__serialize_input(self, mock_url, mock_client, mock_signer):
804833
"""Ensures predict model passes with valid input parameters."""
805834

806835
df = pd.DataFrame([1, 2, 3])

0 commit comments

Comments
 (0)