|
24 | 24 | MODEL_DATA = "s3://bucket/model.tar.gz" |
25 | 25 | MODEL_IMAGE = "mi" |
26 | 26 | ENTRY_POINT = "blah.py" |
27 | | -INSTANCE_TYPE = "p2.xlarge" |
28 | 27 | ROLE = "some-role" |
29 | 28 |
|
30 | 29 | DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data") |
@@ -172,145 +171,6 @@ def test_create_no_defaults(sagemaker_session, tmpdir): |
172 | 171 | } |
173 | 172 |
|
174 | 173 |
|
175 | | -@patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock()) |
176 | | -@patch("time.strftime", MagicMock(return_value=TIMESTAMP)) |
177 | | -def test_deploy(sagemaker_session, tmpdir): |
178 | | - model = DummyFrameworkModel(sagemaker_session, source_dir=str(tmpdir)) |
179 | | - model.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=1) |
180 | | - sagemaker_session.endpoint_from_production_variants.assert_called_with( |
181 | | - name=MODEL_NAME, |
182 | | - production_variants=[ |
183 | | - { |
184 | | - "InitialVariantWeight": 1, |
185 | | - "ModelName": MODEL_NAME, |
186 | | - "InstanceType": INSTANCE_TYPE, |
187 | | - "InitialInstanceCount": 1, |
188 | | - "VariantName": "AllTraffic", |
189 | | - } |
190 | | - ], |
191 | | - tags=None, |
192 | | - kms_key=None, |
193 | | - wait=True, |
194 | | - data_capture_config_dict=None, |
195 | | - ) |
196 | | - |
197 | | - |
198 | | -@patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock()) |
199 | | -@patch("time.strftime", MagicMock(return_value=TIMESTAMP)) |
200 | | -def test_deploy_endpoint_name(sagemaker_session, tmpdir): |
201 | | - model = DummyFrameworkModel(sagemaker_session, source_dir=str(tmpdir)) |
202 | | - model.deploy(endpoint_name="blah", instance_type=INSTANCE_TYPE, initial_instance_count=55) |
203 | | - sagemaker_session.endpoint_from_production_variants.assert_called_with( |
204 | | - name="blah", |
205 | | - production_variants=[ |
206 | | - { |
207 | | - "InitialVariantWeight": 1, |
208 | | - "ModelName": MODEL_NAME, |
209 | | - "InstanceType": INSTANCE_TYPE, |
210 | | - "InitialInstanceCount": 55, |
211 | | - "VariantName": "AllTraffic", |
212 | | - } |
213 | | - ], |
214 | | - tags=None, |
215 | | - kms_key=None, |
216 | | - wait=True, |
217 | | - data_capture_config_dict=None, |
218 | | - ) |
219 | | - |
220 | | - |
221 | | -@patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock()) |
222 | | -@patch("time.strftime", MagicMock(return_value=TIMESTAMP)) |
223 | | -def test_deploy_tags(sagemaker_session, tmpdir): |
224 | | - model = DummyFrameworkModel(sagemaker_session, source_dir=str(tmpdir)) |
225 | | - tags = [{"ModelName": "TestModel"}] |
226 | | - model.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=1, tags=tags) |
227 | | - sagemaker_session.endpoint_from_production_variants.assert_called_with( |
228 | | - name=MODEL_NAME, |
229 | | - production_variants=[ |
230 | | - { |
231 | | - "InitialVariantWeight": 1, |
232 | | - "ModelName": MODEL_NAME, |
233 | | - "InstanceType": INSTANCE_TYPE, |
234 | | - "InitialInstanceCount": 1, |
235 | | - "VariantName": "AllTraffic", |
236 | | - } |
237 | | - ], |
238 | | - tags=tags, |
239 | | - kms_key=None, |
240 | | - wait=True, |
241 | | - data_capture_config_dict=None, |
242 | | - ) |
243 | | - |
244 | | - |
245 | | -@patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock()) |
246 | | -@patch("tarfile.open") |
247 | | -@patch("time.strftime", return_value=TIMESTAMP) |
248 | | -def test_deploy_accelerator_type(tfo, time, sagemaker_session): |
249 | | - model = DummyFrameworkModel(sagemaker_session) |
250 | | - model.deploy( |
251 | | - instance_type=INSTANCE_TYPE, initial_instance_count=1, accelerator_type=ACCELERATOR_TYPE |
252 | | - ) |
253 | | - sagemaker_session.endpoint_from_production_variants.assert_called_with( |
254 | | - name=MODEL_NAME, |
255 | | - production_variants=[ |
256 | | - { |
257 | | - "InitialVariantWeight": 1, |
258 | | - "ModelName": MODEL_NAME, |
259 | | - "InstanceType": INSTANCE_TYPE, |
260 | | - "InitialInstanceCount": 1, |
261 | | - "VariantName": "AllTraffic", |
262 | | - "AcceleratorType": ACCELERATOR_TYPE, |
263 | | - } |
264 | | - ], |
265 | | - tags=None, |
266 | | - kms_key=None, |
267 | | - wait=True, |
268 | | - data_capture_config_dict=None, |
269 | | - ) |
270 | | - |
271 | | - |
272 | | -@patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock()) |
273 | | -@patch("tarfile.open") |
274 | | -@patch("time.strftime", return_value=TIMESTAMP) |
275 | | -def test_deploy_kms_key(tfo, time, sagemaker_session): |
276 | | - key = "some-key-arn" |
277 | | - model = DummyFrameworkModel(sagemaker_session) |
278 | | - model.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=1, kms_key=key) |
279 | | - sagemaker_session.endpoint_from_production_variants.assert_called_with( |
280 | | - name=MODEL_NAME, |
281 | | - production_variants=[ |
282 | | - { |
283 | | - "InitialVariantWeight": 1, |
284 | | - "ModelName": MODEL_NAME, |
285 | | - "InstanceType": INSTANCE_TYPE, |
286 | | - "InitialInstanceCount": 1, |
287 | | - "VariantName": "AllTraffic", |
288 | | - } |
289 | | - ], |
290 | | - tags=None, |
291 | | - kms_key=key, |
292 | | - wait=True, |
293 | | - data_capture_config_dict=None, |
294 | | - ) |
295 | | - |
296 | | - |
297 | | -@patch("sagemaker.session.Session") |
298 | | -@patch("sagemaker.local.LocalSession") |
299 | | -@patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock()) |
300 | | -def test_deploy_creates_correct_session(local_session, session, tmpdir): |
301 | | - # We expect a LocalSession when deploying to instance_type = 'local' |
302 | | - model = DummyFrameworkModel(sagemaker_session=None, source_dir=str(tmpdir)) |
303 | | - model.deploy(endpoint_name="blah", instance_type="local", initial_instance_count=1) |
304 | | - assert model.sagemaker_session == local_session.return_value |
305 | | - |
306 | | - # We expect a real Session when deploying to instance_type != local/local_gpu |
307 | | - model = DummyFrameworkModel(sagemaker_session=None, source_dir=str(tmpdir)) |
308 | | - model.deploy( |
309 | | - endpoint_name="remote_endpoint", instance_type="ml.m4.4xlarge", initial_instance_count=2 |
310 | | - ) |
311 | | - assert model.sagemaker_session == session.return_value |
312 | | - |
313 | | - |
314 | 174 | @patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock()) |
315 | 175 | def test_deploy_update_endpoint(sagemaker_session, tmpdir): |
316 | 176 | model = DummyFrameworkModel(sagemaker_session, source_dir=tmpdir) |
|
0 commit comments