168168 "training_script" : None ,
169169}
170170
171+ INFERENCE_CONDA_ENV = "oci://bucket@namespace/<path_to_service_pack>"
172+ TRAINING_CONDA_ENV = "oci://bucket@namespace/<path_to_service_pack>"
173+ DEFAULT_PYTHON_VERSION = "3.8"
174+ MODEL_FILE_NAME = "fake_model_name"
175+ FAKE_MD_URL = "http://<model-deployment-url>"
176+
177+
178+ def _prepare (model ):
179+ model .prepare (
180+ inference_conda_env = INFERENCE_CONDA_ENV ,
181+ inference_python_version = DEFAULT_PYTHON_VERSION ,
182+ training_conda_env = TRAINING_CONDA_ENV ,
183+ training_python_version = DEFAULT_PYTHON_VERSION ,
184+ model_file_name = MODEL_FILE_NAME ,
185+ force_overwrite = True ,
186+ )
187+
171188
172189class TestEstimator :
173190 def predict (self , x ):
174191 return x ** 2
175192
176193
177194class TestGenericModel :
178-
179195 iris = load_iris ()
180196 X , y = iris .data , iris .target
181197 X_train , X_test , y_train , y_test = train_test_split (X , y )
@@ -298,16 +314,22 @@ def test_prepare_both_conda_env(self, mock_signer):
298314 )
299315
300316 @patch ("ads.common.auth.default_signer" )
301- def test_verify_without_reload (self , mock_signer ):
302- """Test verify input data without reload artifacts ."""
317+ def test_prepare_with_custom_scorepy (self , mock_signer ):
318+ """Test prepare a trained model with custom score.py ."""
303319 self .generic_model .prepare (
304- inference_conda_env = "oci://service-conda-packs@ociodscdev/service_pack/cpu/General_Machine_Learning_for_CPUs/1.0/mlcpuv1" ,
305- inference_python_version = "3.6" ,
306- training_conda_env = "oci://service-conda-packs@ociodscdev/service_pack/cpu/Oracle_Database_for_CPU_Python_3.7/1.0/database_p37_cpu_v1" ,
307- training_python_version = "3.7" ,
320+ INFERENCE_CONDA_ENV ,
308321 model_file_name = "fake_model_name" ,
309- force_overwrite = True ,
322+ score_py_uri = f" { os . path . dirname ( os . path . abspath ( __file__ )) } /test_files/custom_score.py" ,
310323 )
324+ assert os .path .exists (os .path .join ("fake_folder" , "score.py" ))
325+
326+ prediction = self .generic_model .verify (data = "test" )["prediction" ]
327+ assert prediction == "This is a custom score.py."
328+
329+ @patch ("ads.common.auth.default_signer" )
330+ def test_verify_without_reload (self , mock_signer ):
331+ """Test verify input data without reload artifacts."""
332+ _prepare (self .generic_model )
311333 self .generic_model .verify (self .X_test .tolist ())
312334
313335 with patch ("ads.model.artifact.ModelArtifact.reload" ) as mock_reload :
@@ -317,20 +339,10 @@ def test_verify_without_reload(self, mock_signer):
317339 @patch ("ads.common.auth.default_signer" )
318340 def test_verify (self , mock_signer ):
319341 """Test verify input data"""
320- self .generic_model .prepare (
321- inference_conda_env = "oci://service-conda-packs@ociodscdev/service_pack/cpu/General_Machine_Learning_for_CPUs/1.0/mlcpuv1" ,
322- inference_python_version = "3.6" ,
323- training_conda_env = "oci://service-conda-packs@ociodscdev/service_pack/cpu/Oracle_Database_for_CPU_Python_3.7/1.0/database_p37_cpu_v1" ,
324- training_python_version = "3.7" ,
325- model_file_name = "fake_model_name" ,
326- force_overwrite = True ,
327- )
342+ _prepare (self .generic_model )
328343 prediction_1 = self .generic_model .verify (self .X_test .tolist ())
329344 assert isinstance (prediction_1 , dict ), "Failed to verify json payload."
330345
331- prediction_2 = self .generic_model .verify (self .X_test .tolist ())
332- assert isinstance (prediction_2 , dict ), "Failed to verify input data."
333-
334346 def test_reload (self ):
335347 """test the reload."""
336348 pass
@@ -622,11 +634,31 @@ def test_deploy_with_default_display_name(self, mock_deploy):
622634 == random_name [:- 9 ]
623635 )
624636
637+ @pytest .mark .parametrize ("input_data" , [(X_test .tolist ())])
638+ @patch ("ads.common.auth.default_signer" )
639+ def test_predict_locally (self , mock_signer , input_data ):
640+ _prepare (self .generic_model )
641+ test_result = self .generic_model .predict (data = input_data , local = True )
642+ expected_result = self .generic_model .estimator .predict (input_data ).tolist ()
643+ assert (
644+ test_result ["prediction" ] == expected_result
645+ ), "Failed to verify input data."
646+
647+ with patch ("ads.model.artifact.ModelArtifact.reload" ) as mock_reload :
648+ self .generic_model .predict (
649+ data = input_data , local = True , reload_artifacts = False
650+ )
651+ mock_reload .assert_not_called ()
652+
625653 @patch .object (ModelDeployment , "predict" )
626654 @patch ("ads.common.auth.default_signer" )
627655 @patch ("ads.common.oci_client.OCIClientFactory" )
656+ @patch (
657+ "ads.model.deployment.model_deployment.ModelDeployment.url" ,
658+ return_value = FAKE_MD_URL ,
659+ )
628660 def test_predict_with_not_active_deployment_fail (
629- self , mock_client , mock_signer , mock_predict
661+ self , mock_url , mock_client , mock_signer , mock_predict
630662 ):
631663 """Ensures predict model fails in case of model deployment is not in an active state."""
632664 with pytest .raises (NotActiveDeploymentError ):
@@ -646,7 +678,11 @@ def test_predict_with_not_active_deployment_fail(
646678
647679 @patch ("ads.common.auth.default_signer" )
648680 @patch ("ads.common.oci_client.OCIClientFactory" )
649- def test_predict_bytes_success (self , mock_client , mock_signer ):
681+ @patch (
682+ "ads.model.deployment.model_deployment.ModelDeployment.url" ,
683+ return_value = FAKE_MD_URL ,
684+ )
685+ def test_predict_bytes_success (self , mock_url , mock_client , mock_signer ):
650686 """Ensures predict model passes with bytes input."""
651687 with patch .object (
652688 ModelDeployment , "state" , new_callable = PropertyMock
@@ -655,7 +691,7 @@ def test_predict_bytes_success(self, mock_client, mock_signer):
655691 with patch .object (ModelDeployment , "predict" ) as mock_predict :
656692 mock_predict .return_value = {"result" : "result" }
657693 self .generic_model .model_deployment = ModelDeployment (
658- model_deployment_id = "test"
694+ model_deployment_id = "test" ,
659695 )
660696 # self.generic_model.model_deployment.current_state = ModelDeploymentState.ACTIVE
661697 self .generic_model ._as_onnx = False
@@ -668,7 +704,11 @@ def test_predict_bytes_success(self, mock_client, mock_signer):
668704
669705 @patch ("ads.common.auth.default_signer" )
670706 @patch ("ads.common.oci_client.OCIClientFactory" )
671- def test_predict_success (self , mock_client , mock_signer ):
707+ @patch (
708+ "ads.model.deployment.model_deployment.ModelDeployment.url" ,
709+ return_value = FAKE_MD_URL ,
710+ )
711+ def test_predict_success (self , mock_url , mock_client , mock_signer ):
672712 """Ensures predict model passes with valid input parameters."""
673713 with patch .object (
674714 ModelDeployment , "state" , new_callable = PropertyMock
@@ -785,7 +825,11 @@ def test_from_model_artifact(
785825
786826 @patch ("ads.common.auth.default_signer" )
787827 @patch ("ads.common.oci_client.OCIClientFactory" )
788- def test_predict_success__serialize_input (self , mock_client , mock_signer ):
828+ @patch (
829+ "ads.model.deployment.model_deployment.ModelDeployment.url" ,
830+ return_value = FAKE_MD_URL ,
831+ )
832+ def test_predict_success__serialize_input (self , mock_url , mock_client , mock_signer ):
789833 """Ensures predict model passes with valid input parameters."""
790834
791835 df = pd .DataFrame ([1 , 2 , 3 ])
@@ -795,7 +839,6 @@ def test_predict_success__serialize_input(self, mock_client, mock_signer):
795839 with patch .object (
796840 GenericModel , "get_data_serializer"
797841 ) as mock_get_data_serializer :
798-
799842 mock_get_data_serializer .return_value .data = df .to_json ()
800843 mock_state .return_value = ModelDeploymentState .ACTIVE
801844 with patch .object (ModelDeployment , "predict" ) as mock_predict :
@@ -1782,7 +1825,6 @@ def test_upload_artifact_fail(self):
17821825 def test_upload_artifact_success (self ):
17831826 """Tests uploading model artifacts to the provided `uri`."""
17841827 with tempfile .TemporaryDirectory () as tmp_dir :
1785-
17861828 # copy test artifacts to the temp folder
17871829 shutil .copytree (
17881830 os .path .join (self .curr_dir , "test_files/valid_model_artifacts" ),
0 commit comments