@@ -376,8 +376,8 @@ def model_config_from_estimator(instance_type, estimator, role=None, image=None,
376376 role (str): The ``ExecutionRoleArn`` IAM Role ARN for the model
377377 image (str): An container image to use for deploying the model
378378 model_server_workers (int): The number of worker processes used by the inference server.
379- If None, server will use one worker per vCPU. Only effective when estimator is
380- SageMaker framework.
379+ If None, server will use one worker per vCPU. Only effective when estimator is a
380+ SageMaker framework.
381381 vpc_config_override (dict[str, list[str]]): Override for VpcConfig set on the model.
382382 Default: use subnets and security groups from this Estimator.
383383 * 'Subnets' (list[str]): List of subnet ids.
@@ -394,5 +394,223 @@ def model_config_from_estimator(instance_type, estimator, role=None, image=None,
394394 elif isinstance (estimator , sagemaker .estimator .Framework ):
395395 model = estimator .create_model (model_server_workers = model_server_workers , role = role ,
396396 vpc_config_override = vpc_config_override )
397+ else :
398+ raise TypeError ('Estimator must be one of sagemaker.estimator.Estimator, sagemaker.estimator.Framework'
399+ ' or sagemaker.amazon.amazon_estimator.AmazonAlgorithmEstimatorBase.' )
397400
398401 return model_config (instance_type , model , role , image )
402+
403+
def transform_config(transformer, data, data_type='S3Prefix', content_type=None, compression_type=None,
                     split_type=None, job_name=None):
    """Export an Airflow transform config from a SageMaker transformer.

    Args:
        transformer (sagemaker.transformer.Transformer): The SageMaker transformer to export
            Airflow config from.
        data (str): Input data location in S3.
        data_type (str): What the S3 location defines (default: 'S3Prefix'). Valid values:

            * 'S3Prefix' - the S3 URI defines a key name prefix. All objects with this prefix
              will be used as inputs for the transform job.
            * 'ManifestFile' - the S3 URI points to a single manifest file listing each S3
              object to use as an input for the transform job.

        content_type (str): MIME type of the input data (default: None).
        compression_type (str): Compression type of the input data, if compressed (default: None).
            Valid values: 'Gzip', None.
        split_type (str): The record delimiter for the input object (default: 'None').
            Valid values: 'None', 'Line', and 'RecordIO'.
        job_name (str): job name (default: None). If not specified, one will be generated.

    Returns:
        dict: Transform config that can be directly used by SageMakerTransformOperator in Airflow.
    """
    # Resolve the job name: explicit argument wins, then an Airflow-style name derived
    # from the base job name, then the model name as a last resort.
    if job_name is None:
        base_name = transformer.base_transform_job_name
        if base_name is None:
            transformer._current_job_name = transformer.model_name
        else:
            transformer._current_job_name = utils.airflow_name_from_base(base_name)
    else:
        transformer._current_job_name = job_name

    # Default the output location to the session's default bucket, keyed by job name.
    if transformer.output_path is None:
        default_bucket = transformer.sagemaker_session.default_bucket()
        transformer.output_path = 's3://{}/{}'.format(default_bucket, transformer._current_job_name)

    job_config = sagemaker.transformer._TransformJob._load_config(
        data, data_type, content_type, compression_type, split_type, transformer)

    config = {
        'TransformJobName': transformer._current_job_name,
        'ModelName': transformer.model_name,
        'TransformInput': job_config['input_config'],
        'TransformOutput': job_config['output_config'],
        'TransformResources': job_config['resource_config'],
    }

    # Optional request fields are only emitted when set on the transformer.
    optional_fields = (
        ('BatchStrategy', transformer.strategy),
        ('MaxConcurrentTransforms', transformer.max_concurrent_transforms),
        ('MaxPayloadInMB', transformer.max_payload),
        ('Environment', transformer.env),
        ('Tags', transformer.tags),
    )
    for key, value in optional_fields:
        if value is not None:
            config[key] = value

    return config
467+
468+
def transform_config_from_estimator(estimator, instance_count, instance_type, data, data_type='S3Prefix',
                                    content_type=None, compression_type=None, split_type=None,
                                    job_name=None, strategy=None, assemble_with=None, output_path=None,
                                    output_kms_key=None, accept=None, env=None, max_concurrent_transforms=None,
                                    max_payload=None, tags=None, role=None, volume_kms_key=None,
                                    model_server_workers=None, image=None, vpc_config_override=None):
    """Export an Airflow transform config from a SageMaker estimator.

    Args:
        estimator (sagemaker.model.EstimatorBase): The SageMaker estimator to export Airflow
            config from. It has to be an estimator associated with a training job.
        instance_count (int): Number of EC2 instances to use.
        instance_type (str): Type of EC2 instance to use, for example, 'ml.c4.xlarge'.
        data (str): Input data location in S3.
        data_type (str): What the S3 location defines (default: 'S3Prefix'). Valid values:

            * 'S3Prefix' - the S3 URI defines a key name prefix. All objects with this prefix
              will be used as inputs for the transform job.
            * 'ManifestFile' - the S3 URI points to a single manifest file listing each S3
              object to use as an input for the transform job.

        content_type (str): MIME type of the input data (default: None).
        compression_type (str): Compression type of the input data, if compressed (default: None).
            Valid values: 'Gzip', None.
        split_type (str): The record delimiter for the input object (default: 'None').
            Valid values: 'None', 'Line', and 'RecordIO'.
        job_name (str): job name (default: None). If not specified, one will be generated.
        strategy (str): The strategy used to decide how to batch records in a single request
            (default: None). Valid values: 'MULTI_RECORD' and 'SINGLE_RECORD'.
        assemble_with (str): How the output is assembled (default: None). Valid values: 'Line'
            or 'None'.
        output_path (str): S3 location for saving the transform result. If not specified,
            results are stored to a default bucket.
        output_kms_key (str): Optional. KMS key ID for encrypting the transform output
            (default: None).
        accept (str): The content type accepted by the endpoint deployed during the transform job.
        env (dict): Environment variables to be set for use during the transform job
            (default: None).
        max_concurrent_transforms (int): The maximum number of HTTP requests to be made to
            each individual transform container at one time.
        max_payload (int): Maximum size of the payload in a single HTTP request to the
            container in MB.
        tags (list[dict]): List of tags for labeling a transform job. If none specified, then
            the tags used for the training job are used for the transform job.
        role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``, which is also
            used during transform jobs. If not specified, the role from the Estimator will
            be used.
        volume_kms_key (str): Optional. KMS key ID for encrypting the volume attached to the
            ML compute instance (default: None).
        model_server_workers (int): Optional. The number of worker processes used by the
            inference server. If None, server will use one worker per vCPU.
        image (str): An container image to use for deploying the model
        vpc_config_override (dict[str, list[str]]): Override for VpcConfig set on the model.
            Default: use subnets and security groups from this Estimator.
            * 'Subnets' (list[str]): List of subnet ids.
            * 'SecurityGroupIds' (list[str]): List of security group ids.

    Returns:
        dict: Transform config that can be directly used by SageMakerTransformOperator in Airflow.
    """
    model_base_config = model_config_from_estimator(
        instance_type=instance_type, estimator=estimator, role=role, image=image,
        model_server_workers=model_server_workers, vpc_config_override=vpc_config_override)

    # Framework estimators accept an extra positional argument (model_server_workers)
    # just before volume_kms_key; otherwise the argument lists are identical.
    transformer_args = [instance_count, instance_type, strategy, assemble_with, output_path,
                       output_kms_key, accept, env, max_concurrent_transforms,
                       max_payload, tags, role]
    if isinstance(estimator, sagemaker.estimator.Framework):
        transformer_args.append(model_server_workers)
    transformer_args.append(volume_kms_key)
    transformer = estimator.transformer(*transformer_args)

    transform_base_config = transform_config(transformer, data, data_type, content_type,
                                             compression_type, split_type, job_name)

    return {
        'Model': model_base_config,
        'Transform': transform_base_config,
    }
546+
547+
def deploy_config(model, initial_instance_count, instance_type, endpoint_name=None, tags=None):
    """Export an Airflow deploy config from a SageMaker model.

    Args:
        model (sagemaker.model.Model): The SageMaker model to export the Airflow config from.
        initial_instance_count (int): The initial number of instances to run in the
            ``Endpoint`` created from this ``Model``.
        instance_type (str): The EC2 instance type to deploy this Model to. For example,
            'ml.p2.xlarge'.
        endpoint_name (str): The name of the endpoint to create (default: None).
            If not specified, a unique endpoint name will be created.
        tags (list[dict]): List of tags for labeling a training job. For more, see
            https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.

    Returns:
        dict: Deploy config that can be directly used by SageMakerEndpointOperator in Airflow.
    """
    model_base_config = model_config(instance_type, model)

    # The endpoint config is named after the model; the endpoint falls back to the
    # same name when no explicit endpoint_name is supplied.
    endpoint_config_name = model.name
    production_variant = sagemaker.production_variant(model.name, instance_type,
                                                      initial_instance_count)

    endpoint_config = {
        'EndpointConfigName': endpoint_config_name,
        'ProductionVariants': [production_variant],
    }
    if tags is not None:
        endpoint_config['Tags'] = tags

    endpoint_base_config = {
        'EndpointName': endpoint_name or endpoint_config_name,
        'EndpointConfigName': endpoint_config_name,
    }

    config = {
        'Model': model_base_config,
        'EndpointConfig': endpoint_config,
        'Endpoint': endpoint_base_config,
    }

    # If the model config carries S3 operations, hoist them to the root of the
    # config so the Airflow operator can see them.
    s3_operations = model_base_config.pop('S3Operations', None)
    if s3_operations is not None:
        config['S3Operations'] = s3_operations

    return config
591+
592+
def deploy_config_from_estimator(estimator, initial_instance_count, instance_type, endpoint_name=None,
                                 tags=None, **kwargs):
    """Export an Airflow deploy config from a SageMaker estimator.

    Args:
        estimator (sagemaker.model.EstimatorBase): The SageMaker estimator to export Airflow
            config from. It has to be an estimator associated with a training job.
        initial_instance_count (int): Minimum number of EC2 instances to deploy to an
            endpoint for prediction.
        instance_type (str): Type of EC2 instance to deploy to an endpoint for prediction,
            for example, 'ml.c4.xlarge'.
        endpoint_name (str): Name to use for creating an Amazon SageMaker endpoint. If not
            specified, the name of the training job is used.
        tags (list[dict]): List of tags for labeling a training job. For more, see
            https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
        **kwargs: Passed to invocation of ``create_model()``. Implementations may customize
            ``create_model()`` to accept ``**kwargs`` to customize model creation during
            deploy. For more, see the implementation docs.

    Returns:
        dict: Deploy config that can be directly used by SageMakerEndpointOperator in Airflow.
    """
    # Build the model from the trained estimator, then delegate to deploy_config.
    return deploy_config(estimator.create_model(**kwargs), initial_instance_count,
                         instance_type, endpoint_name, tags)
0 commit comments