2626MODEL_NAME = "{}-{}" .format (MODEL_IMAGE , TIMESTAMP )
2727
2828INSTANCE_COUNT = 2
29- INSTANCE_TYPE = "c4.4xlarge"
29+ INSTANCE_TYPE = "ml. c4.4xlarge"
3030ROLE = "some-role"
3131
3232BASE_PRODUCTION_VARIANT = {
@@ -43,17 +43,119 @@ def sagemaker_session():
4343 return Mock ()
4444
4545
46- @patch ("sagemaker.production_variant" )
46+ def test_prepare_container_def ():
47+ env = {"FOO" : "BAR" }
48+ model = Model (MODEL_DATA , MODEL_IMAGE , env = env )
49+
50+ container_def = model .prepare_container_def (INSTANCE_TYPE , "ml.eia.medium" )
51+
52+ expected = {"Image" : MODEL_IMAGE , "Environment" : env , "ModelDataUrl" : MODEL_DATA }
53+ assert expected == container_def
54+
55+
4756@patch ("sagemaker.model.Model.prepare_container_def" )
4857@patch ("sagemaker.utils.name_from_image" )
49- def test_deploy (name_from_image , prepare_container_def , production_variant , sagemaker_session ):
58+ def test_create_sagemaker_model (name_from_image , prepare_container_def , sagemaker_session ):
59+ name_from_image .return_value = MODEL_NAME
60+
61+ container_def = {"Image" : MODEL_IMAGE , "Environment" : {}, "ModelDataUrl" : MODEL_DATA }
62+ prepare_container_def .return_value = container_def
63+
64+ model = Model (MODEL_DATA , MODEL_IMAGE , sagemaker_session = sagemaker_session )
65+ model ._create_sagemaker_model (INSTANCE_TYPE )
66+
67+ prepare_container_def .assert_called_with (INSTANCE_TYPE , accelerator_type = None )
68+ name_from_image .assert_called_with (MODEL_IMAGE )
69+
70+ sagemaker_session .create_model .assert_called_with (
71+ MODEL_NAME , None , container_def , vpc_config = None , enable_network_isolation = False , tags = None
72+ )
73+
74+
75+ @patch ("sagemaker.utils.name_from_image" , Mock ())
76+ @patch ("sagemaker.model.Model.prepare_container_def" )
77+ def test_create_sagemaker_model_accelerator_type (prepare_container_def , sagemaker_session ):
78+ model = Model (MODEL_DATA , MODEL_IMAGE , sagemaker_session = sagemaker_session )
79+
80+ accelerator_type = "ml.eia.medium"
81+ model ._create_sagemaker_model (INSTANCE_TYPE , accelerator_type = accelerator_type )
82+
83+ prepare_container_def .assert_called_with (INSTANCE_TYPE , accelerator_type = accelerator_type )
84+
85+
86+ @patch ("sagemaker.model.Model.prepare_container_def" )
87+ @patch ("sagemaker.utils.name_from_image" )
88+ def test_create_sagemaker_model_tags (name_from_image , prepare_container_def , sagemaker_session ):
89+ container_def = {"Image" : MODEL_IMAGE , "Environment" : {}, "ModelDataUrl" : MODEL_DATA }
90+ prepare_container_def .return_value = container_def
91+
5092 name_from_image .return_value = MODEL_NAME
5193
94+ model = Model (MODEL_DATA , MODEL_IMAGE , sagemaker_session = sagemaker_session )
95+
96+ tags = {"Key" : "foo" , "Value" : "bar" }
97+ model ._create_sagemaker_model (INSTANCE_TYPE , tags = tags )
98+
99+ sagemaker_session .create_model .assert_called_with (
100+ MODEL_NAME , None , container_def , vpc_config = None , enable_network_isolation = False , tags = tags
101+ )
102+
103+
104+ @patch ("sagemaker.model.Model.prepare_container_def" )
105+ @patch ("sagemaker.utils.name_from_image" )
106+ def test_create_sagemaker_model_optional_model_params (
107+ name_from_image , prepare_container_def , sagemaker_session
108+ ):
52109 container_def = {"Image" : MODEL_IMAGE , "Environment" : {}, "ModelDataUrl" : MODEL_DATA }
53110 prepare_container_def .return_value = container_def
54111
112+ vpc_config = {"Subnets" : ["123" ], "SecurityGroupIds" : ["456" , "789" ]}
113+
114+ model = Model (
115+ MODEL_DATA ,
116+ MODEL_IMAGE ,
117+ name = MODEL_NAME ,
118+ role = ROLE ,
119+ vpc_config = vpc_config ,
120+ enable_network_isolation = True ,
121+ sagemaker_session = sagemaker_session ,
122+ )
123+ model ._create_sagemaker_model (INSTANCE_TYPE )
124+
125+ name_from_image .assert_not_called ()
126+
127+ sagemaker_session .create_model .assert_called_with (
128+ MODEL_NAME ,
129+ ROLE ,
130+ container_def ,
131+ vpc_config = vpc_config ,
132+ enable_network_isolation = True ,
133+ tags = None ,
134+ )
135+
136+
137+ @patch ("sagemaker.session.Session" )
138+ @patch ("sagemaker.local.LocalSession" )
139+ def test_create_sagemaker_model_creates_correct_session (local_session , session ):
140+ model = Model (MODEL_DATA , MODEL_IMAGE )
141+ model ._create_sagemaker_model ("local" )
142+ assert model .sagemaker_session == local_session .return_value
143+
144+ model = Model (MODEL_DATA , MODEL_IMAGE )
145+ model ._create_sagemaker_model ("ml.m5.xlarge" )
146+ assert model .sagemaker_session == session .return_value
147+
148+
149+ @patch ("sagemaker.production_variant" )
150+ @patch ("sagemaker.model.Model.prepare_container_def" )
151+ @patch ("sagemaker.utils.name_from_image" )
152+ def test_deploy (name_from_image , prepare_container_def , production_variant , sagemaker_session ):
153+ name_from_image .return_value = MODEL_NAME
55154 production_variant .return_value = BASE_PRODUCTION_VARIANT
56155
156+ container_def = {"Image" : MODEL_IMAGE , "Environment" : {}, "ModelDataUrl" : MODEL_DATA }
157+ prepare_container_def .return_value = container_def
158+
57159 model = Model (MODEL_DATA , MODEL_IMAGE , role = ROLE , sagemaker_session = sagemaker_session )
58160 model .deploy (instance_type = INSTANCE_TYPE , initial_instance_count = INSTANCE_COUNT )
59161
@@ -223,7 +325,7 @@ def test_deploy_data_capture_config(production_variant, sagemaker_session):
223325
224326@patch ("sagemaker.session.Session" )
225327@patch ("sagemaker.local.LocalSession" )
226- def test_deploy_creates_correct_session (local_session , session , tmpdir ):
328+ def test_deploy_creates_correct_session (local_session , session ):
227329 # We expect a LocalSession when deploying to instance_type = 'local'
228330 model = Model (MODEL_DATA , MODEL_IMAGE , role = ROLE )
229331 model .deploy (endpoint_name = "blah" , instance_type = "local" , initial_instance_count = 1 )
@@ -356,7 +458,6 @@ def test_model_create_transformer_network_isolation(create_sagemaker_model, sage
356458
357459@patch ("sagemaker.session.Session" )
358460@patch ("sagemaker.local.LocalSession" )
359- @patch ("sagemaker.fw_utils.tar_and_upload_dir" , Mock ())
360461def test_transformer_creates_correct_session (local_session , session ):
361462 model = Model (MODEL_DATA , MODEL_IMAGE , sagemaker_session = None )
362463 transformer = model .transformer (instance_count = 1 , instance_type = "local" )
0 commit comments