1111# ANY KIND, either express or implied. See the License for the specific
1212# language governing permissions and limitations under the License.
1313from __future__ import absolute_import
14- import numpy
14+
1515import os
1616import time
17+
18+ import numpy
1719import pytest
20+ import tests .integ
21+ from tests .integ import DATA_DIR , PYTHON_VERSION , TRAINING_DEFAULT_TIMEOUT_MINUTES
22+ from tests .integ .timeout import timeout , timeout_and_delete_endpoint_by_name
23+
1824from sagemaker .pytorch .estimator import PyTorch
1925from sagemaker .pytorch .model import PyTorchModel
2026from sagemaker .utils import sagemaker_timestamp
21- from tests .integ import DATA_DIR , PYTHON_VERSION , TRAINING_DEFAULT_TIMEOUT_MINUTES , REGION
22- from tests .integ .timeout import timeout , timeout_and_delete_endpoint_by_name
2327
2428MNIST_DIR = os .path .join (DATA_DIR , 'pytorch_mnist' )
2529MNIST_SCRIPT = os .path .join (MNIST_DIR , 'mnist.py' )
@@ -57,9 +61,11 @@ def test_deploy_model(pytorch_training_job, sagemaker_session):
5761 endpoint_name = 'test-pytorch-deploy-model-{}' .format (sagemaker_timestamp ())
5862
5963 with timeout_and_delete_endpoint_by_name (endpoint_name , sagemaker_session ):
60- desc = sagemaker_session .sagemaker_client .describe_training_job (TrainingJobName = pytorch_training_job )
64+ desc = sagemaker_session .sagemaker_client .describe_training_job (
65+ TrainingJobName = pytorch_training_job )
6166 model_data = desc ['ModelArtifacts' ]['S3ModelArtifacts' ]
62- model = PyTorchModel (model_data , 'SageMakerRole' , entry_point = MNIST_SCRIPT , sagemaker_session = sagemaker_session )
67+ model = PyTorchModel (model_data , 'SageMakerRole' , entry_point = MNIST_SCRIPT ,
68+ sagemaker_session = sagemaker_session )
6369 predictor = model .deploy (1 , 'ml.m4.xlarge' , endpoint_name = endpoint_name )
6470
6571 batch_size = 100
@@ -69,7 +75,7 @@ def test_deploy_model(pytorch_training_job, sagemaker_session):
6975 assert output .shape == (batch_size , 10 )
7076
7177
72- @pytest .mark .skipif (REGION in ['us-west-1' , 'eu-west-2' , 'ca-central-1' ],
78+ @pytest .mark .skipif (tests . integ . test_region () in ['us-west-1' , 'eu-west-2' , 'ca-central-1' ],
7379 reason = 'No ml.p2.xlarge supported in these regions' )
7480def test_async_fit_deploy (sagemaker_session , pytorch_full_version ):
7581 training_job_name = ""
@@ -90,7 +96,8 @@ def test_async_fit_deploy(sagemaker_session, pytorch_full_version):
9096
9197 with timeout_and_delete_endpoint_by_name (endpoint_name , sagemaker_session ):
9298 print ("Re-attaching now to: %s" % training_job_name )
93- estimator = PyTorch .attach (training_job_name = training_job_name , sagemaker_session = sagemaker_session )
99+ estimator = PyTorch .attach (training_job_name = training_job_name ,
100+ sagemaker_session = sagemaker_session )
94101 predictor = estimator .deploy (1 , instance_type , endpoint_name = endpoint_name )
95102
96103 batch_size = 100
@@ -105,7 +112,8 @@ def test_failed_training_job(sagemaker_session, pytorch_full_version):
105112 script_path = os .path .join (MNIST_DIR , 'failure_script.py' )
106113
107114 with timeout (minutes = TRAINING_DEFAULT_TIMEOUT_MINUTES ):
108- pytorch = _get_pytorch_estimator (sagemaker_session , pytorch_full_version , entry_point = script_path )
115+ pytorch = _get_pytorch_estimator (sagemaker_session , pytorch_full_version ,
116+ entry_point = script_path )
109117
110118 with pytest .raises (ValueError ) as e :
111119 pytorch .fit ()
@@ -119,8 +127,10 @@ def _upload_training_data(pytorch):
119127
120128def _get_pytorch_estimator (sagemaker_session , pytorch_full_version , instance_type = 'ml.c4.xlarge' ,
121129 entry_point = MNIST_SCRIPT ):
122- return PyTorch (entry_point = entry_point , role = 'SageMakerRole' , framework_version = pytorch_full_version ,
123- py_version = PYTHON_VERSION , train_instance_count = 1 , train_instance_type = instance_type ,
130+ return PyTorch (entry_point = entry_point , role = 'SageMakerRole' ,
131+ framework_version = pytorch_full_version ,
132+ py_version = PYTHON_VERSION , train_instance_count = 1 ,
133+ train_instance_type = instance_type ,
124134 sagemaker_session = sagemaker_session )
125135
126136
0 commit comments