@@ -544,7 +544,8 @@ def transform(self, job_name, model_name, strategy, max_concurrent_transforms, m
544544 self .sagemaker_client .create_transform_job (** transform_request )
545545
546546 def create_model (self , name , role , container_defs , vpc_config = None ,
547- enable_network_isolation = False , primary_container = None ):
547+ enable_network_isolation = False , primary_container = None ,
548+ tags = None ):
548549 """Create an Amazon SageMaker ``Model``.
549550 Specify the S3 location of the model artifacts and Docker image containing
550551 the inference code. Amazon SageMaker uses this information to deploy the
@@ -570,6 +571,11 @@ def create_model(self, name, role, container_defs, vpc_config=None,
570571 You can also specify the return value of ``sagemaker.container_def()``, which is used to create
571572 more advanced container configurations, including model containers which need artifacts from S3. This
572573 field is deprecated, please use container_defs instead.
574+ tags(List[dict[str, str]]): Optional. The list of tags to add to the model. Example:
575+ >>> tags = [{'Key': 'tagname', 'Value': 'tagvalue'}]
576+ For more information about tags, see https://boto3.amazonaws.com/v1/documentation\
577+ /api/latest/reference/services/sagemaker.html#SageMaker.Client.add_tags
578+
573579
574580 Returns:
575581 str: Name of the Amazon SageMaker ``Model`` created.
@@ -583,12 +589,16 @@ def create_model(self, name, role, container_defs, vpc_config=None,
583589 container_defs = primary_container
584590
585591 role = self .expand_role (role )
586- create_model_request = {}
592+
587593 if isinstance (container_defs , list ):
588- create_model_request = _create_model_request ( name = name , role = role , container_def = container_defs )
594+ container_definition = container_defs
589595 else :
590- primary_container = _expand_container_def (container_defs )
591- create_model_request = _create_model_request (name = name , role = role , container_def = primary_container )
596+ container_definition = _expand_container_def (container_defs )
597+
598+ create_model_request = _create_model_request (name = name ,
599+ role = role ,
600+ container_def = container_definition ,
601+ tags = tags )
592602
593603 if vpc_config :
594604 create_model_request ['VpcConfig' ] = vpc_config
@@ -702,7 +712,8 @@ def wait_for_model_package(self, model_package_name, poll=5):
702712 model_package_name , status , reason ))
703713 return desc
704714
705- def create_endpoint_config (self , name , model_name , initial_instance_count , instance_type , accelerator_type = None ):
715+ def create_endpoint_config (self , name , model_name , initial_instance_count , instance_type ,
716+ accelerator_type = None , tags = None ):
706717 """Create an Amazon SageMaker endpoint configuration.
707718
708719 The endpoint configuration identifies the Amazon SageMaker model (created using the
@@ -717,17 +728,24 @@ def create_endpoint_config(self, name, model_name, initial_instance_count, insta
717728 instance_type (str): Type of EC2 instance to launch, for example, 'ml.c4.xlarge'.
718729 accelerator_type (str): Type of Elastic Inference accelerator to attach to the instance. For example,
719730 'ml.eia1.medium'. For more information: https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html
731+ tags(List[dict[str, str]]): Optional. The list of tags to add to the endpoint config. Example:
732+ >>> tags = [{'Key': 'tagname', 'Value': 'tagvalue'}]
733+ For more information about tags, see https://boto3.amazonaws.com/v1/documentation\
734+ /api/latest/reference/services/sagemaker.html#SageMaker.Client.add_tags
720735
721736
722737 Returns:
723738 str: Name of the endpoint point configuration created.
724739 """
725740 LOGGER .info ('Creating endpoint-config with name {}' .format (name ))
726741
742+ tags = tags or []
743+
727744 self .sagemaker_client .create_endpoint_config (
728745 EndpointConfigName = name ,
729746 ProductionVariants = [production_variant (model_name , instance_type , initial_instance_count ,
730- accelerator_type = accelerator_type )]
747+ accelerator_type = accelerator_type )],
748+ Tags = tags
731749 )
732750 return name
733751
@@ -1383,19 +1401,18 @@ def __init__(self, model_data, image, env=None):
13831401 self .env = env
13841402
13851403
1386- def _create_model_request (name , role , container_def = None ): # pylint: disable=redefined-outer-name
1404+ def _create_model_request (name , role , container_def = None , tags = None ): # pylint: disable=redefined-outer-name
1405+ request = {'ModelName' : name , 'ExecutionRoleArn' : role }
1406+
13871407 if isinstance (container_def , list ):
1388- return {
1389- 'ModelName' : name ,
1390- 'Containers' : container_def ,
1391- 'ExecutionRoleArn' : role
1392- }
1408+ request ['Containers' ] = container_def
13931409 else :
1394- return {
1395- 'ModelName' : name ,
1396- 'PrimaryContainer' : container_def ,
1397- 'ExecutionRoleArn' : role
1398- }
1410+ request ['PrimaryContainer' ] = container_def
1411+
1412+ if tags :
1413+ request ['Tags' ] = tags
1414+
1415+ return request
13991416
14001417
14011418def _deployment_entity_exists (describe_fn ):
0 commit comments