170170
171171INFERENCE_CONDA_ENV = "oci://bucket@namespace/<path_to_service_pack>"
172172TRAINING_CONDA_ENV = "oci://bucket@namespace/<path_to_service_pack>"
173+ DEFAULT_PYTHON_VERSION = "3.8"
174+ MODEL_FILE_NAME = "fake_model_name"
175+ FAKE_MD_URL = "http://<model-deployment-url>"
176+
177+
178+ def _prepare (model ):
179+ model .prepare (
180+ inference_conda_env = INFERENCE_CONDA_ENV ,
181+ inference_python_version = DEFAULT_PYTHON_VERSION ,
182+ training_conda_env = TRAINING_CONDA_ENV ,
183+ training_python_version = DEFAULT_PYTHON_VERSION ,
184+ model_file_name = MODEL_FILE_NAME ,
185+ force_overwrite = True ,
186+ )
173187
174188
175189class TestEstimator :
@@ -315,14 +329,7 @@ def test_prepare_with_custom_scorepy(self, mock_signer):
315329 @patch ("ads.common.auth.default_signer" )
316330 def test_verify_without_reload (self , mock_signer ):
317331 """Test verify input data without reload artifacts."""
318- self .generic_model .prepare (
319- inference_conda_env = "oci://service-conda-packs@ociodscdev/service_pack/cpu/General_Machine_Learning_for_CPUs/1.0/mlcpuv1" ,
320- inference_python_version = "3.6" ,
321- training_conda_env = "oci://service-conda-packs@ociodscdev/service_pack/cpu/Oracle_Database_for_CPU_Python_3.7/1.0/database_p37_cpu_v1" ,
322- training_python_version = "3.7" ,
323- model_file_name = "fake_model_name" ,
324- force_overwrite = True ,
325- )
332+ _prepare (self .generic_model )
326333 self .generic_model .verify (self .X_test .tolist ())
327334
328335 with patch ("ads.model.artifact.ModelArtifact.reload" ) as mock_reload :
@@ -332,20 +339,10 @@ def test_verify_without_reload(self, mock_signer):
332339 @patch ("ads.common.auth.default_signer" )
333340 def test_verify (self , mock_signer ):
334341 """Test verify input data"""
335- self .generic_model .prepare (
336- inference_conda_env = "oci://service-conda-packs@ociodscdev/service_pack/cpu/General_Machine_Learning_for_CPUs/1.0/mlcpuv1" ,
337- inference_python_version = "3.6" ,
338- training_conda_env = "oci://service-conda-packs@ociodscdev/service_pack/cpu/Oracle_Database_for_CPU_Python_3.7/1.0/database_p37_cpu_v1" ,
339- training_python_version = "3.7" ,
340- model_file_name = "fake_model_name" ,
341- force_overwrite = True ,
342- )
342+ _prepare (self .generic_model )
343343 prediction_1 = self .generic_model .verify (self .X_test .tolist ())
344344 assert isinstance (prediction_1 , dict ), "Failed to verify json payload."
345345
346- prediction_2 = self .generic_model .verify (self .X_test .tolist ())
347- assert isinstance (prediction_2 , dict ), "Failed to verify input data."
348-
349346 def test_reload (self ):
350347 """test the reload."""
351348 pass
@@ -637,11 +634,31 @@ def test_deploy_with_default_display_name(self, mock_deploy):
637634 == random_name [:- 9 ]
638635 )
639636
637+ @pytest .mark .parametrize ("input_data" , [(X_test .tolist ())])
638+ @patch ("ads.common.auth.default_signer" )
639+ def test_predict_locally (self , mock_signer , input_data ):
640+ _prepare (self .generic_model )
641+ test_result = self .generic_model .predict (data = input_data , local = True )
642+ expected_result = self .generic_model .estimator .predict (input_data ).tolist ()
643+ assert (
644+ test_result ["prediction" ] == expected_result
645+ ), "Failed to verify input data."
646+
647+ with patch ("ads.model.artifact.ModelArtifact.reload" ) as mock_reload :
648+ self .generic_model .predict (
649+ data = input_data , local = True , reload_artifacts = False
650+ )
651+ mock_reload .assert_not_called ()
652+
640653 @patch .object (ModelDeployment , "predict" )
641654 @patch ("ads.common.auth.default_signer" )
642655 @patch ("ads.common.oci_client.OCIClientFactory" )
656+ @patch (
657+ "ads.model.deployment.model_deployment.ModelDeployment.url" ,
658+ return_value = FAKE_MD_URL ,
659+ )
643660 def test_predict_with_not_active_deployment_fail (
644- self , mock_client , mock_signer , mock_predict
661+ self , mock_url , mock_client , mock_signer , mock_predict
645662 ):
646663 """Ensures predict model fails in case of model deployment is not in an active state."""
647664 with pytest .raises (NotActiveDeploymentError ):
@@ -661,7 +678,11 @@ def test_predict_with_not_active_deployment_fail(
661678
662679 @patch ("ads.common.auth.default_signer" )
663680 @patch ("ads.common.oci_client.OCIClientFactory" )
664- def test_predict_bytes_success (self , mock_client , mock_signer ):
681+ @patch (
682+ "ads.model.deployment.model_deployment.ModelDeployment.url" ,
683+ return_value = FAKE_MD_URL ,
684+ )
685+ def test_predict_bytes_success (self , mock_url , mock_client , mock_signer ):
665686 """Ensures predict model passes with bytes input."""
666687 with patch .object (
667688 ModelDeployment , "state" , new_callable = PropertyMock
@@ -670,7 +691,7 @@ def test_predict_bytes_success(self, mock_client, mock_signer):
670691 with patch .object (ModelDeployment , "predict" ) as mock_predict :
671692 mock_predict .return_value = {"result" : "result" }
672693 self .generic_model .model_deployment = ModelDeployment (
673- model_deployment_id = "test"
694+ model_deployment_id = "test" ,
674695 )
675696 # self.generic_model.model_deployment.current_state = ModelDeploymentState.ACTIVE
676697 self .generic_model ._as_onnx = False
@@ -683,7 +704,11 @@ def test_predict_bytes_success(self, mock_client, mock_signer):
683704
684705 @patch ("ads.common.auth.default_signer" )
685706 @patch ("ads.common.oci_client.OCIClientFactory" )
686- def test_predict_success (self , mock_client , mock_signer ):
707+ @patch (
708+ "ads.model.deployment.model_deployment.ModelDeployment.url" ,
709+ return_value = FAKE_MD_URL ,
710+ )
711+ def test_predict_success (self , mock_url , mock_client , mock_signer ):
687712 """Ensures predict model passes with valid input parameters."""
688713 with patch .object (
689714 ModelDeployment , "state" , new_callable = PropertyMock
@@ -800,7 +825,11 @@ def test_from_model_artifact(
800825
801826 @patch ("ads.common.auth.default_signer" )
802827 @patch ("ads.common.oci_client.OCIClientFactory" )
803- def test_predict_success__serialize_input (self , mock_client , mock_signer ):
828+ @patch (
829+ "ads.model.deployment.model_deployment.ModelDeployment.url" ,
830+ return_value = FAKE_MD_URL ,
831+ )
832+ def test_predict_success__serialize_input (self , mock_url , mock_client , mock_signer ):
804833 """Ensures predict model passes with valid input parameters."""
805834
806835 df = pd .DataFrame ([1 , 2 , 3 ])
0 commit comments