1212# language governing permissions and limitations under the License.
1313from __future__ import absolute_import
1414
15+ import json
1516import logging
1617
1718import sagemaker
1819from sagemaker import fw_utils , local , session , utils
1920
# Target device identifiers for Neo compilation.
# NOTE(review): not referenced by the code visible in this module -- presumably validated by callers; confirm.
NEO_ALLOWED_TARGET_INSTANCE_FAMILY = {
    'ml_c5',
    'ml_m5',
    'ml_c4',
    'ml_m4',
    'jetson_tx1',
    'jetson_tx2',
    'ml_p2',
    'ml_p3',
    'deeplens',
    'rasp3b',
}

# Lower-case framework names that ``Model.compile`` accepts.
NEO_ALLOWED_FRAMEWORKS = {'mxnet', 'tensorflow', 'pytorch', 'onnx', 'xgboost'}

# Region name -> account ID passed as ``account`` when building the Neo image URI.
NEO_IMAGE_ACCOUNT = dict([
    ('us-west-2', '301217895009'),
    ('us-east-1', '785573368785'),
    ('eu-west-1', '802834080501'),
    ('us-east-2', '007439368137'),
])
31+
2032
2133class Model (object ):
2234 """A SageMaker ``Model`` that can be deployed to an ``Endpoint``."""
@@ -53,6 +65,7 @@ def __init__(self, model_data, image, role=None, predictor_cls=None, env=None, n
5365 self .vpc_config = vpc_config
5466 self .sagemaker_session = sagemaker_session
5567 self ._model_name = None
68+ self ._is_compiled_model = False
5669
5770 def prepare_container_def (self , instance_type ): # pylint: disable=unused-argument
5871 """Return a dict created by ``sagemaker.container_def()`` for deploying this model to a specified instance type.
@@ -68,6 +81,93 @@ def prepare_container_def(self, instance_type): # pylint: disable=unused-argume
6881 """
6982 return sagemaker .container_def (self .image , self .model_data , self .env )
7083
84+ def _framework (self ):
85+ return getattr (self , '__framework_name__' , None )
86+
87+ def _get_framework_version (self ):
88+ return getattr (self , 'framework_version' , None )
89+
90+ def _compilation_job_config (self , target_instance_type , input_shape , output_path , role , compile_max_run ,
91+ job_name , framework , tags ):
92+ input_model_config = {
93+ 'S3Uri' : self .model_data ,
94+ 'DataInputConfig' : input_shape if type (input_shape ) != dict else json .dumps (input_shape ),
95+ 'Framework' : framework
96+ }
97+ role = self .sagemaker_session .expand_role (role )
98+ output_model_config = {
99+ 'TargetDevice' : target_instance_type ,
100+ 'S3OutputLocation' : output_path
101+ }
102+
103+ return {'input_model_config' : input_model_config ,
104+ 'output_model_config' : output_model_config ,
105+ 'role' : role ,
106+ 'stop_condition' : {
107+ 'MaxRuntimeInSeconds' : compile_max_run
108+ },
109+ 'tags' : tags ,
110+ 'job_name' : job_name }
111+
def _neo_image_account(self, region):
    """Return the account ID registered for ``region`` in ``NEO_IMAGE_ACCOUNT``.

    Args:
        region (str): Region name to look up.

    Returns:
        str: The account ID mapped to ``region``.

    Raises:
        ValueError: If ``region`` has no entry in ``NEO_IMAGE_ACCOUNT``.
    """
    if region not in NEO_IMAGE_ACCOUNT:
        # Join the sorted keys so the message is stable and readable; formatting ``.keys()`` directly
        # renders as ``dict_keys([...])`` on Python 3.
        valid_regions = ', '.join(sorted(NEO_IMAGE_ACCOUNT))
        raise ValueError("Neo is not currently supported in {}, "
                         "valid regions: {}".format(region, valid_regions))
    return NEO_IMAGE_ACCOUNT[region]
117+
def _neo_image(self, region, target_instance_type, framework, framework_version):
    """Build the image reference for a Neo-compiled model via ``fw_utils.create_image_uri``.

    Args:
        region (str): Region name; also used to look up the image account.
        target_instance_type (str): Target device name; underscores are replaced with dots before it
            is handed to ``create_image_uri`` (e.g. ``ml_c5`` becomes ``ml.c5``).
        framework (str): Framework name; lower-cased and prefixed with ``neo-``.
        framework_version (str): Framework version, forwarded unchanged.

    Returns:
        The value returned by ``fw_utils.create_image_uri``.

    Raises:
        ValueError: If ``region`` is not a key of ``NEO_IMAGE_ACCOUNT``.
    """
    neo_framework = 'neo-' + framework.lower()
    dotted_instance_type = target_instance_type.replace('_', '.')
    account_id = self._neo_image_account(region)
    return fw_utils.create_image_uri(region, neo_framework, dotted_instance_type, framework_version,
                                     py_version='py3', account=account_id)
125+
def compile(self, target_instance_family, input_shape, output_path, role,
            tags=None, job_name=None, compile_max_run=5 * 60, framework=None, framework_version=None):
    """Compile this ``Model`` with SageMaker Neo.

    Runs a compilation job, waits for it, and then points this model at the compiled artifacts and
    the matching Neo image.

    Args:
        target_instance_family (str): Identifies the device that you want to run your model on after
            compilation, for example: ml_c5. Allowed strings are: ml_c5, ml_m5, ml_c4, ml_m4,
            jetson_tx1, jetson_tx2, ml_p2, ml_p3, deeplens, rasp3b
        input_shape (dict or str): Specifies the name and shape of the expected inputs for your trained
            model, for example: {'data': [1, 3, 1024, 1024]}, or
            {'var1': [1, 1, 28, 28], 'var2': [1, 1, 28, 28]}. A dict is serialized to JSON; any other
            value is sent as given.
        output_path (str): Specifies where to store the compiled model
        role (str): Execution role
        tags (list[dict]): List of tags for labeling a compilation job. For more, see
            https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
        job_name (str): The name of the compilation job. Required even though it defaults to ``None``.
        compile_max_run (int): Timeout in seconds for compilation (default: 5 * 60).
            After this amount of time Amazon SageMaker Neo terminates the compilation job regardless
            of its current status.
        framework (str): The framework that is used to train the original model. Allowed values: 'mxnet',
            'tensorflow', 'pytorch', 'onnx', 'xgboost'. Only used when this model does not define a
            ``__framework_name__`` of its own.
        framework_version (str): Version of the framework. Only used when this model does not define a
            ``framework_version`` of its own.

    Returns:
        sagemaker.model.Model: This same object, with ``model_data`` and ``image`` updated for the
            compiled model. See :func:`~sagemaker.model.Model` for full details.

    Raises:
        ValueError: If no framework can be determined, the framework is not one of the allowed values,
            or ``job_name`` is not provided.
    """
    # A framework declared on the model itself takes precedence over the argument.
    framework = self._framework() or framework
    if framework is None:
        raise ValueError("You must specify framework, allowed values {}".format(NEO_ALLOWED_FRAMEWORKS))
    if framework not in NEO_ALLOWED_FRAMEWORKS:
        raise ValueError("You must provide valid framework, allowed values {}".format(NEO_ALLOWED_FRAMEWORKS))
    if job_name is None:
        raise ValueError("You must provide a compilation job name")

    # Validation is done on the lower-case name; the job config is built with the upper-case form.
    framework = framework.upper()
    framework_version = self._get_framework_version() or framework_version

    config = self._compilation_job_config(target_instance_family, input_shape, output_path, role,
                                          compile_max_run, job_name, framework, tags)
    self.sagemaker_session.compile_model(**config)
    job_status = self.sagemaker_session.wait_for_compilation_job(job_name)

    # From here on the model refers to the compiled artifacts and the Neo image.
    self.model_data = job_status['ModelArtifacts']['S3ModelArtifacts']
    self.image = self._neo_image(self.sagemaker_session.boto_region_name, target_instance_family, framework,
                                 framework_version)
    self._is_compiled_model = True
    return self
170+
71171 def deploy (self , initial_instance_count , instance_type , endpoint_name = None , tags = None ):
72172 """Deploy this ``Model`` to an ``Endpoint`` and optionally return a ``Predictor``.
73173
@@ -98,13 +198,21 @@ def deploy(self, initial_instance_count, instance_type, endpoint_name=None, tags
98198 else :
99199 self .sagemaker_session = session .Session ()
100200
201+ compiled_model_suffix = '-' .join (instance_type .split ('.' )[:- 1 ])
101202 container_def = self .prepare_container_def (instance_type )
102203 self .name = self .name or utils .name_from_image (container_def ['Image' ])
103204 if self .role is None :
104205 raise ValueError ("Role can not be null for deploying a model" )
206+ if self ._is_compiled_model :
207+ self .name += compiled_model_suffix
105208 self .sagemaker_session .create_model (self .name , self .role , container_def , vpc_config = self .vpc_config )
106209 production_variant = sagemaker .production_variant (self .name , instance_type , initial_instance_count )
107- self .endpoint_name = endpoint_name or self .name
210+ if endpoint_name :
211+ self .endpoint_name = endpoint_name
212+ else :
213+ self .endpoint_name = self .name
214+ if self ._is_compiled_model and not self .endpoint_name .endswith (compiled_model_suffix ):
215+ self .endpoint_name += compiled_model_suffix
108216 self .sagemaker_session .endpoint_from_production_variants (self .endpoint_name , [production_variant ], tags )
109217 if self .predictor_cls :
110218 return self .predictor_cls (self .endpoint_name , self .sagemaker_session )
0 commit comments