@@ -215,6 +215,10 @@ def _compilation_job_config(
215215 job_name ,
216216 framework ,
217217 tags ,
218+ target_platform_os = None ,
219+ target_platform_arch = None ,
220+ target_platform_accelerator = None ,
221+ compiler_options = None ,
218222 ):
219223 """
220224 Args:
@@ -226,20 +230,46 @@ def _compilation_job_config(
226230 job_name:
227231 framework:
228232 tags:
233+ target_platform_os:
234+ target_platform_arch:
235+ target_platform_accelerator:
236+ compiler_options:
229237 """
230238 input_model_config = {
231239 "S3Uri" : self .model_data ,
232- "DataInputConfig" : input_shape
233- if not isinstance (input_shape , dict )
234- else json . dumps ( input_shape ) ,
240+ "DataInputConfig" : json . dumps ( input_shape )
241+ if isinstance (input_shape , dict )
242+ else input_shape ,
235243 "Framework" : framework ,
236244 }
237245 role = self .sagemaker_session .expand_role (role )
238246 output_model_config = {
239- "TargetDevice" : target_instance_type ,
240247 "S3OutputLocation" : output_path ,
241248 }
242249
250+ if target_instance_type is not None :
251+ output_model_config ["TargetDevice" ] = target_instance_type
252+ else :
253+ if target_platform_os is None and target_platform_arch is None :
254+ raise ValueError (
255+ "target_instance_type or (target_platform_os and target_platform_arch) "
256+ "should be provided"
257+ )
258+ target_platform = {
259+ "Os" : target_platform_os ,
260+ "Arch" : target_platform_arch ,
261+ }
262+ if target_platform_accelerator is not None :
263+ target_platform ["Accelerator" ] = target_platform_accelerator
264+ output_model_config ["TargetPlatform" ] = target_platform
265+
266+ if compiler_options is not None :
267+ output_model_config ["CompilerOptions" ] = (
268+ json .dumps (compiler_options )
269+ if isinstance (compiler_options , dict )
270+ else compiler_options
271+ )
272+
243273 return {
244274 "input_model_config" : input_model_config ,
245275 "output_model_config" : output_model_config ,
@@ -320,6 +350,10 @@ def compile(
320350 compile_max_run = 5 * 60 ,
321351 framework = None ,
322352 framework_version = None ,
353+ target_platform_os = None ,
354+ target_platform_arch = None ,
355+ target_platform_accelerator = None ,
356+ compiler_options = None ,
323357 ):
324358 """Compile this ``Model`` with SageMaker Neo.
325359
@@ -328,6 +362,9 @@ def compile(
328362 run your model after compilation, for example: ml_c5. For allowed
329363 strings see
330364 https://docs.aws.amazon.com/sagemaker/latest/dg/API_OutputConfig.html.
365+ Alternatively, you can select an OS, Architecture and Accelerator using
366+ ``target_platform_os``, ``target_platform_arch``,
367+ and ``target_platform_accelerator``.
331368 input_shape (dict): Specifies the name and shape of the expected
332369 inputs for your trained model in json dictionary form, for
333370 example: {'data': [1,3,1024,1024]}, or {'var1': [1,1,28,28],
@@ -345,6 +382,21 @@ def compile(
345382 model. Allowed values: 'mxnet', 'tensorflow', 'keras', 'pytorch',
346383 'onnx', 'xgboost'
347384 framework_version (str):
385+ target_platform_os (str): Target Platform OS, for example: 'LINUX'.
386+ For allowed strings see
387+ https://docs.aws.amazon.com/sagemaker/latest/dg/API_OutputConfig.html.
388+ It can be used instead of target_instance_family.
389+ target_platform_arch (str): Target Platform Architecture, for example: 'X86_64'.
390+ For allowed strings see
391+ https://docs.aws.amazon.com/sagemaker/latest/dg/API_OutputConfig.html.
392+ It can be used instead of target_instance_family.
393+ target_platform_accelerator (str, optional): Target Platform Accelerator,
394+ for example: 'NVIDIA'. For allowed strings see
395+ https://docs.aws.amazon.com/sagemaker/latest/dg/API_OutputConfig.html.
396+ It can be used instead of target_instance_family.
397+ compiler_options (dict, optional): Additional parameters for compiler.
398+ Compiler Options are TargetPlatform / target_instance_family specific. See
399+ https://docs.aws.amazon.com/sagemaker/latest/dg/API_OutputConfig.html for details.
348400
349401 Returns:
350402 sagemaker.model.Model: A SageMaker ``Model`` object. See
@@ -375,31 +427,41 @@ def compile(
375427 job_name ,
376428 framework ,
377429 tags ,
430+ target_platform_os ,
431+ target_platform_arch ,
432+ target_platform_accelerator ,
433+ compiler_options ,
378434 )
379435 self .sagemaker_session .compile_model (** config )
380436 job_status = self .sagemaker_session .wait_for_compilation_job (job_name )
381437 self .model_data = job_status ["ModelArtifacts" ]["S3ModelArtifacts" ]
382- if target_instance_family .startswith ("ml_" ):
383- self .image = self ._neo_image (
384- self .sagemaker_session .boto_region_name ,
385- target_instance_family ,
386- framework ,
387- framework_version ,
388- )
389- self ._is_compiled_model = True
390- elif target_instance_family .startswith (INFERENTIA_INSTANCE_PREFIX ):
391- self .image = self ._inferentia_image (
392- self .sagemaker_session .boto_region_name ,
393- target_instance_family ,
394- framework ,
395- framework_version ,
396- )
397- self ._is_compiled_model = True
438+ if target_instance_family is not None :
439+ if target_instance_family .startswith ("ml_" ):
440+ self .image = self ._neo_image (
441+ self .sagemaker_session .boto_region_name ,
442+ target_instance_family ,
443+ framework ,
444+ framework_version ,
445+ )
446+ self ._is_compiled_model = True
447+ elif target_instance_family .startswith (INFERENTIA_INSTANCE_PREFIX ):
448+ self .image = self ._inferentia_image (
449+ self .sagemaker_session .boto_region_name ,
450+ target_instance_family ,
451+ framework ,
452+ framework_version ,
453+ )
454+ self ._is_compiled_model = True
455+ else :
456+ LOGGER .warning (
457+ "The instance type %s is not supported for deployment via SageMaker."
458+ "Please deploy the model manually." ,
459+ target_instance_family ,
460+ )
398461 else :
399462 LOGGER .warning (
400- "The instance type %s is not supported to deploy via SageMaker,"
401- "please deploy the model manually." ,
402- target_instance_family ,
463+ "Devices described by Target Platform OS, Architecture and Accelerator are not"
464+ "supported for deployment via SageMaker. Please deploy the model manually."
403465 )
404466 return self
405467
0 commit comments