11from ads .opctl .model .cmds import _download_model , download_model
2- import pytest
2+ import pytest
33from unittest .mock import ANY , call , patch
44from ads .model .datascience_model import DataScienceModel
55from unittest .mock import MagicMock , Mock
66from ads .opctl .model .cmds import create_signer
7+ import os
78
89
910@patch .object (DataScienceModel , "from_id" )
1011def test_model__download_model (mock_from_id ):
1112 mock_datascience_model = MagicMock ()
1213 mock_from_id .return_value = mock_datascience_model
13- _download_model ("fake_model_id" , "fake_dir" , "fake_auth" , "region" , "bucket_uri" , 36 , False )
14+ _download_model (
15+ "fake_model_id" , "fake_dir" , "fake_auth" , "region" , "bucket_uri" , 36 , False
16+ )
1417 mock_from_id .assert_called_with ("fake_model_id" )
15- mock_datascience_model .download_artifact .assert_called_with (target_dir = 'fake_dir' , force_overwrite = False , overwrite_existing_artifact = True , remove_existing_artifact = True , auth = 'fake_auth' , region = 'region' , timeout = 36 , bucket_uri = 'bucket_uri' )
18+ mock_datascience_model .download_artifact .assert_called_with (
19+ target_dir = "fake_dir" ,
20+ force_overwrite = False ,
21+ overwrite_existing_artifact = True ,
22+ remove_existing_artifact = True ,
23+ auth = "fake_auth" ,
24+ region = "region" ,
25+ timeout = 36 ,
26+ bucket_uri = "bucket_uri" ,
27+ )
1628
1729
1830@patch .object (DataScienceModel , "from_id" , side_effect = Exception ("Fake error." ))
1931def test_model__download_model_error (mock_from_id ):
2032 with pytest .raises (Exception , match = "Fake error." ):
21- _download_model ("fake_model_id" , "fake_dir" , "fake_auth" , "region" , "bucket_uri" , 36 , False )
33+ _download_model (
34+ "fake_model_id" , "fake_dir" , "fake_auth" , "region" , "bucket_uri" , 36 , False
35+ )
2236
2337
2438@patch ("ads.opctl.model.cmds._download_model" )
2539@patch ("ads.opctl.model.cmds.create_signer" )
2640def test_download_model (mock_create_signer , mock__download_model ):
2741 auth_mock = MagicMock ()
2842 mock_create_signer .return_value = auth_mock
29- download_model (ocid = "fake_model_id" )
43+ download_model (ocid = "fake_model_id" )
3044 mock_create_signer .assert_called_once ()
31- mock__download_model .assert_called_once_with (ocid = 'fake_model_id' , artifact_directory = '/Users/ziye/.ads_ops/models/fake_model_id' , region = None , bucket_uri = None , timeout = None , force_overwrite = False , oci_auth = auth_mock )
32-
45+ mock__download_model .assert_called_once_with (
46+ ocid = "fake_model_id" ,
47+ artifact_directory = os .path .expanduser ("~/.ads_ops/models/fake_model_id" ),
48+ region = None ,
49+ bucket_uri = None ,
50+ timeout = None ,
51+ force_overwrite = False ,
52+ oci_auth = auth_mock ,
53+ )
0 commit comments