168168 "training_script" : None ,
169169}
170170
171-
171+ INFERENCE_CONDA_ENV = "oci://service-conda-packs@ociodscdev/service_pack/cpu/General_Machine_Learning_for_CPUs/1.0/mlcpuv1"
172+ TRAINING_CONDA_ENV = "oci://service-conda-packs@ociodscdev/service_pack/cpu/Oracle_Database_for_CPU_Python_3.7/1.0/database_p37_cpu_v1"
173+ DEFAULT_PYTHON_VERSION = "3.8"
174+ MODEL_FILE_NAME = "fake_model_name"
175+
176+ def _prepare (model ):
177+ model .prepare (
178+ inference_conda_env = INFERENCE_CONDA_ENV ,
179+ inference_python_version = DEFAULT_PYTHON_VERSION ,
180+ training_conda_env = TRAINING_CONDA_ENV ,
181+ training_python_version = DEFAULT_PYTHON_VERSION ,
182+ model_file_name = MODEL_FILE_NAME ,
183+ force_overwrite = True ,
184+ )
172185class TestEstimator :
173186 def predict (self , x ):
174187 return x ** 2
@@ -300,14 +313,7 @@ def test_prepare_both_conda_env(self, mock_signer):
300313 @patch ("ads.common.auth.default_signer" )
301314 def test_verify_without_reload (self , mock_signer ):
302315 """Test verify input data without reload artifacts."""
303- self .generic_model .prepare (
304- inference_conda_env = "oci://service-conda-packs@ociodscdev/service_pack/cpu/General_Machine_Learning_for_CPUs/1.0/mlcpuv1" ,
305- inference_python_version = "3.6" ,
306- training_conda_env = "oci://service-conda-packs@ociodscdev/service_pack/cpu/Oracle_Database_for_CPU_Python_3.7/1.0/database_p37_cpu_v1" ,
307- training_python_version = "3.7" ,
308- model_file_name = "fake_model_name" ,
309- force_overwrite = True ,
310- )
316+ _prepare (self .generic_model )
311317 self .generic_model .verify (self .X_test .tolist ())
312318
313319 with patch ("ads.model.artifact.ModelArtifact.reload" ) as mock_reload :
@@ -317,20 +323,10 @@ def test_verify_without_reload(self, mock_signer):
317323 @patch ("ads.common.auth.default_signer" )
318324 def test_verify (self , mock_signer ):
319325 """Test verify input data"""
320- self .generic_model .prepare (
321- inference_conda_env = "oci://service-conda-packs@ociodscdev/service_pack/cpu/General_Machine_Learning_for_CPUs/1.0/mlcpuv1" ,
322- inference_python_version = "3.6" ,
323- training_conda_env = "oci://service-conda-packs@ociodscdev/service_pack/cpu/Oracle_Database_for_CPU_Python_3.7/1.0/database_p37_cpu_v1" ,
324- training_python_version = "3.7" ,
325- model_file_name = "fake_model_name" ,
326- force_overwrite = True ,
327- )
326+ _prepare (self .generic_model )
328327 prediction_1 = self .generic_model .verify (self .X_test .tolist ())
329328 assert isinstance (prediction_1 , dict ), "Failed to verify json payload."
330329
331- prediction_2 = self .generic_model .verify (self .X_test .tolist ())
332- assert isinstance (prediction_2 , dict ), "Failed to verify input data."
333-
334330 def test_reload (self ):
335331 """test the reload."""
336332 pass
@@ -622,6 +618,21 @@ def test_deploy_with_default_display_name(self, mock_deploy):
622618 == random_name [:- 9 ]
623619 )
624620
621+ @pytest .mark .parametrize (
622+ "input_data" ,
623+ [(X_test .tolist ())]
624+ )
625+ @patch ("ads.common.auth.default_signer" )
626+ def test_predict_locally (self , mock_signer , input_data ):
627+ _prepare (self .generic_model )
628+ test_result = self .generic_model .predict (data = input_data , local = True )
629+ expected_result = self .generic_model .estimator .predict (input_data ).tolist ()
630+ assert test_result ['prediction' ] == expected_result , "Failed to verify input data."
631+
632+ with patch ("ads.model.artifact.ModelArtifact.reload" ) as mock_reload :
633+ self .generic_model .predict (data = input_data , local = True , reload_artifacts = False )
634+ mock_reload .assert_not_called ()
635+
625636 @patch .object (ModelDeployment , "predict" )
626637 @patch ("ads.common.auth.default_signer" )
627638 @patch ("ads.common.oci_client.OCIClientFactory" )
0 commit comments