|
12 | 12 | # language governing permissions and limitations under the License. |
13 | 13 | from __future__ import absolute_import |
14 | 14 |
|
15 | | -import copy |
16 | 15 | import os |
17 | 16 | import subprocess |
18 | 17 |
|
19 | | -import sagemaker |
20 | | -from sagemaker.model import FrameworkModel, Model, ModelPackage |
| 18 | +from sagemaker.model import FrameworkModel |
21 | 19 | from sagemaker.predictor import RealTimePredictor |
22 | 20 |
|
23 | 21 | import pytest |
|
53 | 51 | CODECOMMIT_BRANCH = "master" |
54 | 52 | REPO_DIR = "/tmp/repo_dir" |
55 | 53 |
|
56 | | - |
57 | | -DESCRIBE_MODEL_PACKAGE_RESPONSE = { |
58 | | - "InferenceSpecification": { |
59 | | - "SupportedResponseMIMETypes": ["text"], |
60 | | - "SupportedContentTypes": ["text/csv"], |
61 | | - "SupportedTransformInstanceTypes": ["ml.m4.xlarge", "ml.m4.2xlarge"], |
62 | | - "Containers": [ |
63 | | - { |
64 | | - "Image": "1.dkr.ecr.us-east-2.amazonaws.com/decision-trees-sample:latest", |
65 | | - "ImageDigest": "sha256:1234556789", |
66 | | - "ModelDataUrl": "s3://bucket/output/model.tar.gz", |
67 | | - } |
68 | | - ], |
69 | | - "SupportedRealtimeInferenceInstanceTypes": ["ml.m4.xlarge", "ml.m4.2xlarge"], |
70 | | - }, |
71 | | - "ModelPackageDescription": "Model Package created from training with " |
72 | | - "arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees", |
73 | | - "CreationTime": 1542752036.687, |
74 | | - "ModelPackageArn": "arn:aws:sagemaker:us-east-2:123:model-package/mp-scikit-decision-trees", |
75 | | - "ModelPackageStatusDetails": {"ValidationStatuses": [], "ImageScanStatuses": []}, |
76 | | - "SourceAlgorithmSpecification": { |
77 | | - "SourceAlgorithms": [ |
78 | | - { |
79 | | - "ModelDataUrl": "s3://bucket/output/model.tar.gz", |
80 | | - "AlgorithmName": "arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees", |
81 | | - } |
82 | | - ] |
83 | | - }, |
84 | | - "ModelPackageStatus": "Completed", |
85 | | - "ModelPackageName": "mp-scikit-decision-trees-1542410022-2018-11-20-22-13-56-502", |
86 | | - "CertifyForMarketplace": False, |
87 | | -} |
88 | | - |
89 | 54 | DESCRIBE_COMPILATION_JOB_RESPONSE = { |
90 | 55 | "CompilationJobStatus": "Completed", |
91 | 56 | "ModelArtifacts": {"S3ModelArtifacts": "s3://output-path/model.tar.gz"}, |
@@ -417,181 +382,6 @@ def test_model_enable_network_isolation(sagemaker_session): |
417 | 382 | assert model.enable_network_isolation() is False |
418 | 383 |
|
419 | 384 |
|
420 | | -@patch("sagemaker.model.Model._create_sagemaker_model") |
421 | | -def test_model_create_transformer(create_sagemaker_model, sagemaker_session): |
422 | | - model_name = "auto-generated-model" |
423 | | - model = Model(MODEL_DATA, MODEL_IMAGE, name=model_name, sagemaker_session=sagemaker_session) |
424 | | - |
425 | | - instance_type = "ml.m4.xlarge" |
426 | | - transformer = model.transformer(instance_count=1, instance_type=instance_type) |
427 | | - |
428 | | - create_sagemaker_model.assert_called_with(instance_type, tags=None) |
429 | | - |
430 | | - assert isinstance(transformer, sagemaker.transformer.Transformer) |
431 | | - assert transformer.model_name == model_name |
432 | | - assert transformer.instance_type == instance_type |
433 | | - assert transformer.instance_count == 1 |
434 | | - assert transformer.sagemaker_session == sagemaker_session |
435 | | - assert transformer.base_transform_job_name == model_name |
436 | | - |
437 | | - assert transformer.strategy is None |
438 | | - assert transformer.env is None |
439 | | - assert transformer.output_path is None |
440 | | - assert transformer.output_kms_key is None |
441 | | - assert transformer.accept is None |
442 | | - assert transformer.assemble_with is None |
443 | | - assert transformer.volume_kms_key is None |
444 | | - assert transformer.max_concurrent_transforms is None |
445 | | - assert transformer.max_payload is None |
446 | | - assert transformer.tags is None |
447 | | - |
448 | | - |
449 | | -@patch("sagemaker.model.Model._create_sagemaker_model") |
450 | | -def test_model_create_transformer_optional_params(create_sagemaker_model, sagemaker_session): |
451 | | - model = Model(MODEL_DATA, MODEL_IMAGE, sagemaker_session=sagemaker_session) |
452 | | - |
453 | | - instance_type = "ml.m4.xlarge" |
454 | | - strategy = "MultiRecord" |
455 | | - assemble_with = "Line" |
456 | | - output_path = "s3://bucket/path" |
457 | | - kms_key = "key" |
458 | | - accept = "text/csv" |
459 | | - env = {"test": True} |
460 | | - max_concurrent_transforms = 1 |
461 | | - max_payload = 6 |
462 | | - tags = [{"Key": "k", "Value": "v"}] |
463 | | - |
464 | | - transformer = model.transformer( |
465 | | - instance_count=1, |
466 | | - instance_type=instance_type, |
467 | | - strategy=strategy, |
468 | | - assemble_with=assemble_with, |
469 | | - output_path=output_path, |
470 | | - output_kms_key=kms_key, |
471 | | - accept=accept, |
472 | | - env=env, |
473 | | - max_concurrent_transforms=max_concurrent_transforms, |
474 | | - max_payload=max_payload, |
475 | | - tags=tags, |
476 | | - volume_kms_key=kms_key, |
477 | | - ) |
478 | | - |
479 | | - create_sagemaker_model.assert_called_with(instance_type, tags=tags) |
480 | | - |
481 | | - assert isinstance(transformer, sagemaker.transformer.Transformer) |
482 | | - assert transformer.strategy == strategy |
483 | | - assert transformer.assemble_with == assemble_with |
484 | | - assert transformer.output_path == output_path |
485 | | - assert transformer.output_kms_key == kms_key |
486 | | - assert transformer.accept == accept |
487 | | - assert transformer.max_concurrent_transforms == max_concurrent_transforms |
488 | | - assert transformer.max_payload == max_payload |
489 | | - assert transformer.env == env |
490 | | - assert transformer.tags == tags |
491 | | - assert transformer.volume_kms_key == kms_key |
492 | | - |
493 | | - |
494 | | -@patch("sagemaker.model.Model._create_sagemaker_model") |
495 | | -def test_model_create_transformer_network_isolation(create_sagemaker_model, sagemaker_session): |
496 | | - model = Model( |
497 | | - MODEL_DATA, MODEL_IMAGE, sagemaker_session=sagemaker_session, enable_network_isolation=True |
498 | | - ) |
499 | | - |
500 | | - transformer = model.transformer(1, "ml.m4.xlarge", env={"should_be": "overwritten"}) |
501 | | - assert transformer.env is None |
502 | | - |
503 | | - |
504 | | -@patch("sagemaker.session.Session") |
505 | | -@patch("sagemaker.local.LocalSession") |
506 | | -@patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock()) |
507 | | -def test_transformer_creates_correct_session(local_session, session): |
508 | | - model = Model(MODEL_DATA, MODEL_IMAGE, sagemaker_session=None) |
509 | | - transformer = model.transformer(instance_count=1, instance_type="local") |
510 | | - assert model.sagemaker_session == local_session.return_value |
511 | | - assert transformer.sagemaker_session == local_session.return_value |
512 | | - |
513 | | - model = Model(MODEL_DATA, MODEL_IMAGE, sagemaker_session=None) |
514 | | - transformer = model.transformer(instance_count=1, instance_type="ml.m5.xlarge") |
515 | | - assert model.sagemaker_session == session.return_value |
516 | | - assert transformer.sagemaker_session == session.return_value |
517 | | - |
518 | | - |
519 | | -def test_model_package_enable_network_isolation_with_no_product_id(sagemaker_session): |
520 | | - sagemaker_session.sagemaker_client.describe_model_package = Mock( |
521 | | - return_value=DESCRIBE_MODEL_PACKAGE_RESPONSE |
522 | | - ) |
523 | | - |
524 | | - model_package = ModelPackage( |
525 | | - role="role", model_package_arn="my-model-package", sagemaker_session=sagemaker_session |
526 | | - ) |
527 | | - assert model_package.enable_network_isolation() is False |
528 | | - |
529 | | - |
530 | | -def test_model_package_enable_network_isolation_with_product_id(sagemaker_session): |
531 | | - model_package_response = copy.deepcopy(DESCRIBE_MODEL_PACKAGE_RESPONSE) |
532 | | - model_package_response["InferenceSpecification"]["Containers"].append( |
533 | | - { |
534 | | - "Image": "1.dkr.ecr.us-east-2.amazonaws.com/some-container:latest", |
535 | | - "ModelDataUrl": "s3://bucket/output/model.tar.gz", |
536 | | - "ProductId": "some-product-id", |
537 | | - } |
538 | | - ) |
539 | | - sagemaker_session.sagemaker_client.describe_model_package = Mock( |
540 | | - return_value=model_package_response |
541 | | - ) |
542 | | - |
543 | | - model_package = ModelPackage( |
544 | | - role="role", model_package_arn="my-model-package", sagemaker_session=sagemaker_session |
545 | | - ) |
546 | | - assert model_package.enable_network_isolation() is True |
547 | | - |
548 | | - |
549 | | -@patch("sagemaker.model.ModelPackage._create_sagemaker_model", Mock()) |
550 | | -def test_model_package_create_transformer(sagemaker_session): |
551 | | - sagemaker_session.sagemaker_client.describe_model_package = Mock( |
552 | | - return_value=DESCRIBE_MODEL_PACKAGE_RESPONSE |
553 | | - ) |
554 | | - |
555 | | - model_package = ModelPackage( |
556 | | - role="role", model_package_arn="my-model-package", sagemaker_session=sagemaker_session |
557 | | - ) |
558 | | - model_package.name = "auto-generated-model" |
559 | | - transformer = model_package.transformer( |
560 | | - instance_count=1, instance_type="ml.m4.xlarge", env={"test": True} |
561 | | - ) |
562 | | - assert isinstance(transformer, sagemaker.transformer.Transformer) |
563 | | - assert transformer.model_name == "auto-generated-model" |
564 | | - assert transformer.instance_type == "ml.m4.xlarge" |
565 | | - assert transformer.env == {"test": True} |
566 | | - |
567 | | - |
568 | | -@patch("sagemaker.model.ModelPackage._create_sagemaker_model", Mock()) |
569 | | -def test_model_package_create_transformer_with_product_id(sagemaker_session): |
570 | | - model_package_response = copy.deepcopy(DESCRIBE_MODEL_PACKAGE_RESPONSE) |
571 | | - model_package_response["InferenceSpecification"]["Containers"].append( |
572 | | - { |
573 | | - "Image": "1.dkr.ecr.us-east-2.amazonaws.com/some-container:latest", |
574 | | - "ModelDataUrl": "s3://bucket/output/model.tar.gz", |
575 | | - "ProductId": "some-product-id", |
576 | | - } |
577 | | - ) |
578 | | - sagemaker_session.sagemaker_client.describe_model_package = Mock( |
579 | | - return_value=model_package_response |
580 | | - ) |
581 | | - |
582 | | - model_package = ModelPackage( |
583 | | - role="role", model_package_arn="my-model-package", sagemaker_session=sagemaker_session |
584 | | - ) |
585 | | - model_package.name = "auto-generated-model" |
586 | | - transformer = model_package.transformer( |
587 | | - instance_count=1, instance_type="ml.m4.xlarge", env={"test": True} |
588 | | - ) |
589 | | - assert isinstance(transformer, sagemaker.transformer.Transformer) |
590 | | - assert transformer.model_name == "auto-generated-model" |
591 | | - assert transformer.instance_type == "ml.m4.xlarge" |
592 | | - assert transformer.env is None |
593 | | - |
594 | | - |
595 | 385 | @patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock()) |
596 | 386 | @patch("time.strftime", MagicMock(return_value=TIMESTAMP)) |
597 | 387 | def test_model_delete_model(sagemaker_session, tmpdir): |
|
0 commit comments