4343 ENDPOINT_CONFIG_ASYNC_KMS_KEY_ID_PATH ,
4444 load_sagemaker_config ,
4545)
46+ from sagemaker .model_card .schema_constraints import ModelApprovalStatusEnum
4647from sagemaker .session import Session
4748from sagemaker .model_metrics import ModelMetrics
4849from sagemaker .deprecations import removed_kwargs
@@ -374,12 +375,14 @@ def __init__(
374375 self .dependencies = updates ["dependencies" ]
375376 self .uploaded_code = None
376377 self .repacked_model_data = None
378+ self .content_types = None
379+ self .response_types = None
377380
378381 @runnable_by_pipeline
379382 def register (
380383 self ,
381- content_types : List [Union [str , PipelineVariable ]],
382- response_types : List [Union [str , PipelineVariable ]],
384+ content_types : List [Union [str , PipelineVariable ]] = None ,
385+ response_types : List [Union [str , PipelineVariable ]] = None ,
383386 inference_instances : Optional [List [Union [str , PipelineVariable ]]] = None ,
384387 transform_instances : Optional [List [Union [str , PipelineVariable ]]] = None ,
385388 model_package_name : Optional [Union [str , PipelineVariable ]] = None ,
@@ -456,16 +459,33 @@ def register(
456459 in case the Model instance is built with
457460 :class:`~sagemaker.workflow.pipeline_context.PipelineSession`
458461 """
459- if self .model_data is None :
460- raise ValueError ("SageMaker Model Package cannot be created without model data." )
461462 if isinstance (self .model_data , dict ):
462463 raise ValueError (
463464 "SageMaker Model Package currently cannot be created with ModelDataSource."
464465 )
465466
467+ if content_types is not None :
468+ self .content_types = content_types
469+
470+ if response_types is not None :
471+ self .response_types = response_types
472+
473+ if self .content_types is None :
474+ raise ValueError ("The supported MIME types for the input data is not set" )
475+
476+ if self .response_types is None :
477+ raise ValueError ("The supported MIME types for the output data is not set" )
478+
466479 if image_uri is not None :
467480 self .image_uri = image_uri
468481
482+ if model_package_group_name is None and model_package_name is None :
483+ # If model package group and model package name is not set
484+ # then register to auto-generated model package group
485+ model_package_group_name = utils .base_name_from_image (
486+ self .image_uri , default_base_name = ModelPackage .__name__
487+ )
488+
469489 if model_package_group_name is not None :
470490 container_def = self .prepare_container_def ()
471491 container_def = update_container_with_inference_params (
@@ -478,12 +498,14 @@ def register(
478498 else :
479499 container_def = {
480500 "Image" : self .image_uri ,
481- "ModelDataUrl" : self .model_data ,
482501 }
483502
503+ if self .model_data is not None :
504+ container_def ["ModelDataUrl" ] = self .model_data
505+
484506 model_pkg_args = sagemaker .get_model_package_args (
485- content_types ,
486- response_types ,
507+ self . content_types ,
508+ self . response_types ,
487509 inference_instances = inference_instances ,
488510 transform_instances = transform_instances ,
489511 model_package_name = model_package_name ,
@@ -511,6 +533,7 @@ def register(
511533 role = self .role ,
512534 model_data = self .model_data ,
513535 model_package_arn = model_package .get ("ModelPackageArn" ),
536+ sagemaker_session = self .sagemaker_session ,
514537 )
515538
516539 @runnable_by_pipeline
@@ -1751,6 +1774,7 @@ def __init__(
17511774
17521775# works for MODEL_PACKAGE_ARN with or without version info.
17531776MODEL_PACKAGE_ARN_PATTERN = r"arn:aws:sagemaker:(.*?):(.*?):model-package/(.*?)(?:/(\d+))?$"
1777+ MODEL_PACKAGE_VERSIONED_ARN_PATTERN = r"arn:aws:sagemaker:(.*?):(.*?):model-package/(.*?)/(\d+)$"
17541778
17551779
17561780class ModelPackage (Model ):
@@ -1885,6 +1909,18 @@ def _create_sagemaker_model(self, *args, **kwargs): # pylint: disable=unused-ar
18851909 self ._ensure_base_name_if_needed (model_package_name )
18861910 self ._set_model_name_if_needed ()
18871911
1912+ # Quering the approval status for the model package
1913+ # Approving the versioned model package in case it is not approved
1914+ model_package_desc = self .sagemaker_session .sagemaker_client .describe_model_package (
1915+ ModelPackageName = self .model_package_arn or model_package_name
1916+ )
1917+ if self .model_package_arn is None :
1918+ self .model_package_arn = model_package_desc ["ModelPackageArn" ]
1919+ if re .match (MODEL_PACKAGE_VERSIONED_ARN_PATTERN , self .model_package_arn ):
1920+ approval_status = model_package_desc .get ("ModelApprovalStatus" , "" )
1921+ if approval_status != ModelApprovalStatusEnum .APPROVED :
1922+ self .update_approval_status (approval_status = ModelApprovalStatusEnum .APPROVED )
1923+
18881924 self .sagemaker_session .create_model (
18891925 self .name ,
18901926 self .role ,
@@ -1898,3 +1934,29 @@ def _ensure_base_name_if_needed(self, base_name):
18981934 """Set the base name if there is no model name provided."""
18991935 if self .name is None :
19001936 self ._base_name = base_name
1937+
1938+ def update_approval_status (self , approval_status , approval_description = None ):
1939+ """Update the approval status for the model package
1940+
1941+ Args:
1942+ approval_status (str or PipelineVariable): Model Approval Status, values can be
1943+ "Approved", "Rejected", or "PendingManualApproval".
1944+ approval_description (str): Optional. Description for the approval status of the model
1945+ (default: None).
1946+ """
1947+
1948+ # Models can lazy-init sagemaker_session until deploy() is called to support
1949+ # LocalMode so we must make sure we have an actual session
1950+ sagemaker_session = self .sagemaker_session or sagemaker .Session ()
1951+ if self .model_package_arn is None :
1952+ raise ValueError ("model_package_arn is required to update the status." )
1953+
1954+ update_approval_args = {
1955+ "ModelPackageArn" : self .model_package_arn ,
1956+ "ModelApprovalStatus" : approval_status ,
1957+ }
1958+
1959+ if approval_description is not None :
1960+ update_approval_args ["ApprovalDescription" ] = approval_description
1961+
1962+ sagemaker_session .sagemaker_client .update_model_package (** update_approval_args )
0 commit comments