@@ -1280,6 +1280,7 @@ def test_from_id_fail(self):
12801280
12811281 @patch .object (GenericModel , "from_model_catalog" )
12821282 def test_from_id_model_without_artifact (self , mock_from_model_catalog ):
1283+ """Test to check model artifact is not loaded when load_artifact is set to False"""
12831284 test_model_id = "xxxx.datasciencemodel.xxxx"
12841285 mock_model = MagicMock (model_id = test_model_id , model_artifact = None )
12851286 mock_from_model_catalog .return_value = mock_model
@@ -1307,6 +1308,7 @@ def test_from_id_model_without_artifact(self, mock_from_model_catalog):
13071308
13081309 @patch .object (GenericModel , "from_model_catalog" )
13091310 def test_from_id_with_artifact (self , mock_from_model_catalog ):
1311+ """Test to check model artifact is loaded when load_artifact is set to True"""
13101312 test_model_id = "xxxx.datasciencemodel.xxxx"
13111313 artifact_dir = "test_dir"
13121314 model_artifact = MagicMock (artifact_dir = artifact_dir , reload = False )
@@ -1359,6 +1361,42 @@ def test_download_artifact_fail(self):
13591361 generic_model = GenericModel ()
13601362 generic_model .download_artifact (uri = "" , auth = {"config" : "value" })
13611363
1364+ @patch .object (ModelArtifact , "from_uri" )
1365+ @patch .object (DataScienceModel , "from_id" )
1366+ @patch .object (GenericModel , "reload_runtime_info" )
1367+ @patch .object (DataScienceModel , "download_artifact" )
1368+ def test_download_artifact (
1369+ self ,
1370+ mock_download_artifact ,
1371+ mock_reload_runtime_info ,
1372+ mock_dsc_model_from_id ,
1373+ mock_from_uri ,
1374+ ):
1375+ """Test to check if model artifacts are updated after download_artifact is called"""
1376+ test_model_id = "xxxx.datasciencemodel.xxxx"
1377+ artifact_dir = "test_dir"
1378+ self .generic_model .dsc_model = MagicMock (id = test_model_id )
1379+ self .generic_model .model_artifact = None
1380+ self .generic_model .artifact_dir = artifact_dir
1381+
1382+ mock_dsc_model_from_id .return_value = MagicMock (id = test_model_id )
1383+ mock_download_artifact .return_value = None
1384+ mock_artifact_instance = MagicMock (model = "test_model" )
1385+ mock_from_uri .return_value = mock_artifact_instance
1386+
1387+ assert self .generic_model .model_artifact is None
1388+
1389+ self .generic_model .download_artifact (
1390+ artifact_dir = artifact_dir ,
1391+ auth = {"config" : {}},
1392+ force_overwrite = True ,
1393+ bucket_uri = "bucket_uri" ,
1394+ remove_existing_artifact = True ,
1395+ )
1396+
1397+ mock_reload_runtime_info .assert_called ()
1398+ assert self .generic_model .model_artifact is not None
1399+
13621400 def test_save_without_local_artifact (self ):
13631401 """Test to check if model artifact is available before saving the model"""
13641402 self .generic_model .model_artifact = None
0 commit comments