2222import time
2323
2424from sagemaker .estimator import Framework
25- from sagemaker .fw_utils import framework_name_from_image , framework_version_from_tag , empty_framework_version_warning
26- from sagemaker .utils import get_config_value
27- from sagemaker .vpc_utils import VPC_CONFIG_DEFAULT
28-
25+ from sagemaker .fw_utils import framework_name_from_image , framework_version_from_tag , \
26+ empty_framework_version_warning
2927from sagemaker .tensorflow .defaults import TF_VERSION
3028from sagemaker .tensorflow .model import TensorFlowModel
29+ from sagemaker .tensorflow .serving import Model
30+ from sagemaker .utils import get_config_value
31+ from sagemaker .vpc_utils import VPC_CONFIG_DEFAULT
3132
3233logging .basicConfig ()
3334LOGGER = logging .getLogger ('sagemaker' )
@@ -103,12 +104,14 @@ def validate_requirements(self):
103104 EnvironmentError: If at least one requirement is not installed.
104105 """
105106 if not self ._cmd_exists ('tensorboard' ):
106- raise EnvironmentError ('TensorBoard is not installed in the system. Please install TensorBoard using the'
107- ' following command: \n pip install tensorboard' )
107+ raise EnvironmentError (
108+ 'TensorBoard is not installed in the system. Please install TensorBoard using the'
109+ ' following command: \n pip install tensorboard' )
108110
109111 if not self ._cmd_exists ('aws' ):
110- raise EnvironmentError ('The AWS CLI is not installed in the system. Please install the AWS CLI using the'
111- ' following command: \n pip install awscli' )
112+ raise EnvironmentError (
113+ 'The AWS CLI is not installed in the system. Please install the AWS CLI using the'
114+ ' following command: \n pip install awscli' )
112115
113116 def create_tensorboard_process (self ):
114117 """Create a TensorBoard process.
@@ -125,7 +128,8 @@ def create_tensorboard_process(self):
125128
126129 for i in range (100 ):
127130 p = subprocess .Popen (
128- ["tensorboard" , "--logdir" , self .logdir , "--host" , "localhost" , "--port" , str (port )],
131+ ["tensorboard" , "--logdir" , self .logdir , "--host" , "localhost" , "--port" ,
132+ str (port )],
129133 stdout = subprocess .PIPE ,
130134 stderr = subprocess .PIPE
131135 )
@@ -135,7 +139,8 @@ def create_tensorboard_process(self):
135139 else :
136140 return port , p
137141
138- raise OSError ('No available ports to start TensorBoard. Attempted all ports between 6006 and 6105' )
142+ raise OSError (
143+ 'No available ports to start TensorBoard. Attempted all ports between 6006 and 6105' )
139144
140145 def run (self ):
141146 """Run TensorBoard process."""
@@ -158,7 +163,8 @@ class TensorFlow(Framework):
158163
159164 __framework_name__ = 'tensorflow'
160165
161- def __init__ (self , training_steps = None , evaluation_steps = None , checkpoint_path = None , py_version = 'py2' ,
166+ def __init__ (self , training_steps = None , evaluation_steps = None , checkpoint_path = None ,
167+ py_version = 'py2' ,
162168 framework_version = None , requirements_file = '' , image_name = None , ** kwargs ):
163169 """Initialize an ``TensorFlow`` estimator.
164170 Args:
@@ -202,7 +208,8 @@ def _validate_requirements_file(self, requirements_file):
202208 raise ValueError ('Must specify source_dir along with a requirements file.' )
203209
204210 if os .path .isabs (requirements_file ):
205- raise ValueError ('Requirements file {} is not a path relative to source_dir.' .format (requirements_file ))
211+ raise ValueError ('Requirements file {} is not a path relative to source_dir.' .format (
212+ requirements_file ))
206213
207214 if not os .path .exists (os .path .join (self .source_dir , requirements_file )):
208215 raise ValueError ('Requirements file {} does not exist.' .format (requirements_file ))
@@ -231,6 +238,7 @@ def fit(self, inputs=None, wait=True, logs=True, job_name=None, run_tensorboard_
231238 downloaded checkpoint information (default: False). This is an experimental feature, and requires
232239 TensorBoard and AWS CLI to be installed. It terminates TensorBoard when execution ends.
233240 """
241+
234242 def fit_super ():
235243 super (TensorFlow , self ).fit (inputs , wait , logs , job_name )
236244
@@ -263,7 +271,8 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
263271 dictionary: The transformed init_params
264272
265273 """
266- init_params = super (TensorFlow , cls )._prepare_init_params_from_job_description (job_details , model_channel_name )
274+ init_params = super (TensorFlow , cls )._prepare_init_params_from_job_description (job_details ,
275+ model_channel_name )
267276
268277 # Move some of the tensorflow specific init params from hyperparameters into the main init params.
269278 for argument in ['checkpoint_path' , 'training_steps' , 'evaluation_steps' ]:
@@ -285,15 +294,18 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
285294 # containing framework version, device type and python version (e.g. '1.5-gpu-py2').
286295 # For backward compatibility map deprecated image tag '1.0' to a '1.4' framework version
287296 # otherwise extract framework version from the tag itself.
288- init_params ['framework_version' ] = '1.4' if tag == '1.0' else framework_version_from_tag (tag )
297+ init_params ['framework_version' ] = '1.4' if tag == '1.0' else framework_version_from_tag (
298+ tag )
289299
290300 training_job_name = init_params ['base_job_name' ]
291301 if framework != cls .__framework_name__ :
292- raise ValueError ("Training job: {} didn't use image for requested framework" .format (training_job_name ))
302+ raise ValueError ("Training job: {} didn't use image for requested framework" .format (
303+ training_job_name ))
293304
294305 return init_params
295306
296- def create_model (self , model_server_workers = None , role = None , vpc_config_override = VPC_CONFIG_DEFAULT ):
307+ def create_model (self , model_server_workers = None , role = None ,
308+ vpc_config_override = VPC_CONFIG_DEFAULT , endpoint_type = None ):
297309 """Create a SageMaker ``TensorFlowModel`` object that can be deployed to an ``Endpoint``.
298310
299311 Args:
@@ -305,18 +317,44 @@ def create_model(self, model_server_workers=None, role=None, vpc_config_override
305317 Default: use subnets and security groups from this Estimator.
306318 * 'Subnets' (list[str]): List of subnet ids.
307319 * 'SecurityGroupIds' (list[str]): List of security group ids.
320+ endpoint_type: Optional. Selects the software stack used by the inference server.
321+ If not specified, the model will be configured to use the default
322+ SageMaker model server. If 'tensorflow-serving', the model will be configured to
323+ use the SageMaker Tensorflow Serving container.
308324
309325 Returns:
310326 sagemaker.tensorflow.model.TensorFlowModel: A SageMaker ``TensorFlowModel`` object.
311327 See :func:`~sagemaker.tensorflow.model.TensorFlowModel` for full details.
312328 """
313- env = { 'SAGEMAKER_REQUIREMENTS' : self . requirements_file }
329+
314330 role = role or self .role
315- return TensorFlowModel (self .model_data , role , self .entry_point , source_dir = self ._model_source_dir (),
316- enable_cloudwatch_metrics = self .enable_cloudwatch_metrics , env = env , image = self .image_name ,
317- name = self ._current_job_name , container_log_level = self .container_log_level ,
331+ if endpoint_type == 'tensorflow-serving' :
332+ return self ._create_tfs_model (role = role , vpc_config_override = vpc_config_override )
333+
334+ return self ._create_default_model (model_server_workers = model_server_workers , role = role ,
335+ vpc_config_override = vpc_config_override )
336+
def _create_tfs_model(self, role=None, vpc_config_override=VPC_CONFIG_DEFAULT):
    """Build a TensorFlow Serving ``Model`` from this estimator's attributes.

    Args:
        role: Role handed to the model unchanged.
        vpc_config_override: Forwarded to ``self.get_vpc_config`` to obtain the
            VPC configuration passed to the model.

    Returns:
        The ``Model`` (imported from ``sagemaker.tensorflow.serving``) built from
        this estimator's model data, image, job name, log level, framework
        version and session.
    """
    # Collect the constructor arguments by name first, then build the model
    # in a single call.
    tfs_settings = {
        'model_data': self.model_data,
        'role': role,
        'image': self.image_name,
        'name': self._current_job_name,
        'container_log_level': self.container_log_level,
        'framework_version': self.framework_version,
        'sagemaker_session': self.sagemaker_session,
        'vpc_config': self.get_vpc_config(vpc_config_override),
    }
    return Model(**tfs_settings)
346+
def _create_default_model(self, model_server_workers, role, vpc_config_override):
    """Build the ``TensorFlowModel`` used when no endpoint type is requested.

    Args:
        model_server_workers: Forwarded to ``TensorFlowModel`` as
            ``model_server_workers``.
        role: Role handed to the model unchanged.
        vpc_config_override: Forwarded to ``self.get_vpc_config`` to obtain the
            VPC configuration passed to the model.

    Returns:
        The ``TensorFlowModel`` built from this estimator's attributes.
    """
    # The requirements file is exposed to the model through its environment.
    model_env = {'SAGEMAKER_REQUIREMENTS': self.requirements_file}

    # Resolve the two computed arguments up front, in the same order the
    # single-expression form evaluates them.
    model_source = self._model_source_dir()
    resolved_vpc = self.get_vpc_config(vpc_config_override)

    return TensorFlowModel(
        self.model_data,
        role,
        self.entry_point,
        source_dir=model_source,
        enable_cloudwatch_metrics=self.enable_cloudwatch_metrics,
        env=model_env,
        image=self.image_name,
        name=self._current_job_name,
        container_log_level=self.container_log_level,
        code_location=self.code_location,
        py_version=self.py_version,
        framework_version=self.framework_version,
        model_server_workers=model_server_workers,
        sagemaker_session=self.sagemaker_session,
        vpc_config=resolved_vpc,
    )
322360
0 commit comments