1414from __future__ import absolute_import
1515
1616import logging
17+ from typing import Optional , Union , List , Dict
1718
1819import sagemaker
19- from sagemaker import image_uris
20+ from sagemaker import image_uris , ModelMetrics
2021from sagemaker .deserializers import JSONDeserializer
22+ from sagemaker .drift_check_baselines import DriftCheckBaselines
2123from sagemaker .fw_utils import (
2224 model_code_key_prefix ,
2325 validate_version_or_image_args ,
2426)
27+ from sagemaker .metadata_properties import MetadataProperties
2528from sagemaker .model import FrameworkModel , MODEL_SERVER_WORKERS_PARAM_NAME
2629from sagemaker .predictor import Predictor
2730from sagemaker .serializers import JSONSerializer
2831from sagemaker .session import Session
32+ from sagemaker .utils import to_string
33+ from sagemaker .workflow import is_pipeline_variable
34+ from sagemaker .workflow .entities import PipelineVariable
2935
3036logger = logging .getLogger ("sagemaker" )
3137
@@ -100,16 +106,16 @@ class HuggingFaceModel(FrameworkModel):
100106
101107 def __init__ (
102108 self ,
103- role ,
104- model_data = None ,
105- entry_point = None ,
106- transformers_version = None ,
107- tensorflow_version = None ,
108- pytorch_version = None ,
109- py_version = None ,
110- image_uri = None ,
111- predictor_cls = HuggingFacePredictor ,
112- model_server_workers = None ,
109+ role : str ,
110+ model_data : Optional [ Union [ str , PipelineVariable ]] = None ,
111+ entry_point : Optional [ str ] = None ,
112+ transformers_version : Optional [ str ] = None ,
113+ tensorflow_version : Optional [ str ] = None ,
114+ pytorch_version : Optional [ str ] = None ,
115+ py_version : Optional [ str ] = None ,
116+ image_uri : Optional [ Union [ str , PipelineVariable ]] = None ,
117+ predictor_cls : callable = HuggingFacePredictor ,
118+ model_server_workers : Optional [ Union [ int , PipelineVariable ]] = None ,
113119 ** kwargs ,
114120 ):
115121 """Initialize a HuggingFaceModel.
@@ -299,27 +305,27 @@ def deploy(
299305
300306 def register (
301307 self ,
302- content_types ,
303- response_types ,
304- inference_instances = None ,
305- transform_instances = None ,
306- model_package_name = None ,
307- model_package_group_name = None ,
308- image_uri = None ,
309- model_metrics = None ,
310- metadata_properties = None ,
311- marketplace_cert = False ,
312- approval_status = None ,
313- description = None ,
314- drift_check_baselines = None ,
315- customer_metadata_properties = None ,
316- domain = None ,
317- sample_payload_url = None ,
318- task = None ,
319- framework = None ,
320- framework_version = None ,
321- nearest_model_name = None ,
322- data_input_configuration = None ,
308+ content_types : List [ Union [ str , PipelineVariable ]] ,
309+ response_types : List [ Union [ str , PipelineVariable ]] ,
310+ inference_instances : Optional [ List [ Union [ str , PipelineVariable ]]] = None ,
311+ transform_instances : Optional [ List [ Union [ str , PipelineVariable ]]] = None ,
312+ model_package_name : Optional [ Union [ str , PipelineVariable ]] = None ,
313+ model_package_group_name : Optional [ Union [ str , PipelineVariable ]] = None ,
314+ image_uri : Optional [ Union [ str , PipelineVariable ]] = None ,
315+ model_metrics : Optional [ ModelMetrics ] = None ,
316+ metadata_properties : Optional [ MetadataProperties ] = None ,
317+ marketplace_cert : bool = False ,
318+ approval_status : Optional [ Union [ str , PipelineVariable ]] = None ,
319+ description : Optional [ str ] = None ,
320+ drift_check_baselines : Optional [ DriftCheckBaselines ] = None ,
321+ customer_metadata_properties : Optional [ Dict [ str , Union [ str , PipelineVariable ]]] = None ,
322+ domain : Optional [ Union [ str , PipelineVariable ]] = None ,
323+ sample_payload_url : Optional [ Union [ str , PipelineVariable ]] = None ,
324+ task : Optional [ Union [ str , PipelineVariable ]] = None ,
325+ framework : Optional [ Union [ str , PipelineVariable ]] = None ,
326+ framework_version : Optional [ Union [ str , PipelineVariable ]] = None ,
327+ nearest_model_name : Optional [ Union [ str , PipelineVariable ]] = None ,
328+ data_input_configuration : Optional [ Union [ str , PipelineVariable ]] = None ,
323329 ):
324330 """Creates a model package for creating SageMaker models or listing on Marketplace.
325331
@@ -377,6 +383,13 @@ def register(
377383 region_name = self .sagemaker_session .boto_session .region_name ,
378384 instance_type = instance_type ,
379385 )
386+ if not is_pipeline_variable (framework ):
387+ framework = (
388+ framework
389+ or fetch_framework_and_framework_version (
390+ self .tensorflow_version , self .pytorch_version
391+ )[0 ]
392+ ).upper ()
380393 return super (HuggingFaceModel , self ).register (
381394 content_types ,
382395 response_types ,
@@ -395,12 +408,7 @@ def register(
395408 domain = domain ,
396409 sample_payload_url = sample_payload_url ,
397410 task = task ,
398- framework = (
399- framework
400- or fetch_framework_and_framework_version (
401- self .tensorflow_version , self .pytorch_version
402- )[0 ]
403- ).upper (),
411+ framework = framework ,
404412 framework_version = framework_version
405413 or fetch_framework_and_framework_version (self .tensorflow_version , self .pytorch_version )[
406414 1
@@ -449,7 +457,9 @@ def prepare_container_def(
449457 deploy_env .update (self ._script_mode_env_vars ())
450458
451459 if self .model_server_workers :
452- deploy_env [MODEL_SERVER_WORKERS_PARAM_NAME .upper ()] = str (self .model_server_workers )
460+ deploy_env [MODEL_SERVER_WORKERS_PARAM_NAME .upper ()] = to_string (
461+ self .model_server_workers
462+ )
453463 return sagemaker .container_def (
454464 deploy_image , self .repacked_model_data or self .model_data , deploy_env
455465 )
0 commit comments