Skip to content

Commit 1d8dd19

Browse files
committed
added attribute in predict() to allow for invoking predict in local
1 parent aec602f commit 1d8dd19

File tree

2 files changed

+37
-20
lines changed

2 files changed

+37
-20
lines changed

ads/model/generic_model.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2483,6 +2483,7 @@ def predict(
24832483
self,
24842484
data: Any = None,
24852485
auto_serialize_data: bool = False,
2486+
local: bool = False,
24862487
**kwargs,
24872488
) -> Dict[str, Any]:
24882489
"""Returns prediction of input data run against the model deployment endpoint.
@@ -2507,6 +2508,8 @@ def predict(
25072508
Whether to auto serialize input data. Defauls to `False` for GenericModel, and `True` for other frameworks.
25082509
`data` required to be json serializable if `auto_serialize_data=False`.
25092510
If `auto_serialize_data` set to True, data will be serialized before sending to model deployment endpoint.
2511+
local: bool.
2512+
Whether to invoke the prediction locally. Default to False.
25102513
kwargs:
25112514
content_type: str, used to indicate the media type of the resource.
25122515
image: PIL.Image Object or uri for the image.
@@ -2527,6 +2530,9 @@ def predict(
25272530
ValueError
25282531
If `data` is empty or not JSON serializable.
25292532
"""
2533+
if local:
2534+
return self.verify(data=data, auto_serialize_data=auto_serialize_data, **kwargs)
2535+
25302536
if not self.model_deployment:
25312537
raise ValueError("Use `deploy()` method to start model deployment.")
25322538

tests/unitary/with_extras/model/test_generic_model.py

Lines changed: 31 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,20 @@
168168
"training_script": None,
169169
}
170170

171-
171+
INFERENCE_CONDA_ENV= "oci://service-conda-packs@ociodscdev/service_pack/cpu/General_Machine_Learning_for_CPUs/1.0/mlcpuv1"
172+
TRAINING_CONDA_ENV="oci://service-conda-packs@ociodscdev/service_pack/cpu/Oracle_Database_for_CPU_Python_3.7/1.0/database_p37_cpu_v1"
173+
DEFAULT_PYTHON_VERSION = "3.8"
174+
MODEL_FILE_NAME = "fake_model_name"
175+
176+
def _prepare(model):
177+
model.prepare(
178+
inference_conda_env=INFERENCE_CONDA_ENV,
179+
inference_python_version=DEFAULT_PYTHON_VERSION,
180+
training_conda_env=TRAINING_CONDA_ENV,
181+
training_python_version=DEFAULT_PYTHON_VERSION,
182+
model_file_name=MODEL_FILE_NAME,
183+
force_overwrite=True,
184+
)
172185
class TestEstimator:
173186
def predict(self, x):
174187
return x**2
@@ -300,14 +313,7 @@ def test_prepare_both_conda_env(self, mock_signer):
300313
@patch("ads.common.auth.default_signer")
301314
def test_verify_without_reload(self, mock_signer):
302315
"""Test verify input data without reload artifacts."""
303-
self.generic_model.prepare(
304-
inference_conda_env="oci://service-conda-packs@ociodscdev/service_pack/cpu/General_Machine_Learning_for_CPUs/1.0/mlcpuv1",
305-
inference_python_version="3.6",
306-
training_conda_env="oci://service-conda-packs@ociodscdev/service_pack/cpu/Oracle_Database_for_CPU_Python_3.7/1.0/database_p37_cpu_v1",
307-
training_python_version="3.7",
308-
model_file_name="fake_model_name",
309-
force_overwrite=True,
310-
)
316+
_prepare(self.generic_model)
311317
self.generic_model.verify(self.X_test.tolist())
312318

313319
with patch("ads.model.artifact.ModelArtifact.reload") as mock_reload:
@@ -317,20 +323,10 @@ def test_verify_without_reload(self, mock_signer):
317323
@patch("ads.common.auth.default_signer")
318324
def test_verify(self, mock_signer):
319325
"""Test verify input data"""
320-
self.generic_model.prepare(
321-
inference_conda_env="oci://service-conda-packs@ociodscdev/service_pack/cpu/General_Machine_Learning_for_CPUs/1.0/mlcpuv1",
322-
inference_python_version="3.6",
323-
training_conda_env="oci://service-conda-packs@ociodscdev/service_pack/cpu/Oracle_Database_for_CPU_Python_3.7/1.0/database_p37_cpu_v1",
324-
training_python_version="3.7",
325-
model_file_name="fake_model_name",
326-
force_overwrite=True,
327-
)
326+
_prepare(self.generic_model)
328327
prediction_1 = self.generic_model.verify(self.X_test.tolist())
329328
assert isinstance(prediction_1, dict), "Failed to verify json payload."
330329

331-
prediction_2 = self.generic_model.verify(self.X_test.tolist())
332-
assert isinstance(prediction_2, dict), "Failed to verify input data."
333-
334330
def test_reload(self):
335331
"""test the reload."""
336332
pass
@@ -622,6 +618,21 @@ def test_deploy_with_default_display_name(self, mock_deploy):
622618
== random_name[:-9]
623619
)
624620

621+
@pytest.mark.parametrize(
622+
"input_data",
623+
[(X_test.tolist())]
624+
)
625+
@patch("ads.common.auth.default_signer")
626+
def test_predict_locally(self, mock_signer, input_data):
627+
_prepare(self.generic_model)
628+
test_result = self.generic_model.predict(data=input_data, local=True)
629+
expected_result = self.generic_model.estimator.predict(input_data).tolist()
630+
assert test_result['prediction'] == expected_result, "Failed to verify input data."
631+
632+
with patch("ads.model.artifact.ModelArtifact.reload") as mock_reload:
633+
self.generic_model.predict(data=input_data, local=True, reload_artifacts=False)
634+
mock_reload.assert_not_called()
635+
625636
@patch.object(ModelDeployment, "predict")
626637
@patch("ads.common.auth.default_signer")
627638
@patch("ads.common.oci_client.OCIClientFactory")

0 commit comments

Comments
 (0)