1414from __future__ import absolute_import
1515
1616import logging
17+ import re
1718
1819from typing import List , Dict , Optional
19-
2020import sagemaker
21-
2221from sagemaker .parameter import CategoricalParameter
2322
2423INFERENCE_RECOMMENDER_FRAMEWORK_MAPPING = {
@@ -101,13 +100,15 @@ def right_size(
101100 'OMP_NUM_THREADS': CategoricalParameter(['1', '2', '3', '4'])
102101 }]
103102
104- phases (list[Phase]): Specifies the criteria for increasing load
105- during endpoint load tests. (default: None).
106- traffic_type (str): Specifies the traffic type that matches the phases. (default: None).
107- max_invocations (str): defines invocation limit for endpoint load tests (default: None).
108- model_latency_thresholds (list[ModelLatencyThreshold]): defines the response latency
109- thresholds for endpoint load tests (default: None).
110- max_tests (int): restricts how many endpoints are allowed to be
103+ phases (list[Phase]): Shape of the traffic pattern to use in the load test
104+ (default: None).
105+ traffic_type (str): Specifies the traffic pattern type. Currently only supports
106+ one type 'PHASES' (default: None).
107+ max_invocations (str): defines the minimum invocations per minute for the endpoint
108+ to support (default: None).
109+ model_latency_thresholds (list[ModelLatencyThreshold]): defines the maximum response
110+ latency for endpoints to support (default: None).
111+ max_tests (int): restricts how many endpoints in total are allowed to be
111112 spun up for this job (default: None).
112113 max_parallel_tests (int): restricts how many concurrent endpoints
113114 this job is allowed to spin up (default: None).
@@ -122,7 +123,7 @@ def right_size(
122123 raise ValueError ("right_size() is currently only supported with a registered model" )
123124
124125 if not framework and self ._framework ():
125- framework = INFERENCE_RECOMMENDER_FRAMEWORK_MAPPING .get (self ._framework , framework )
126+ framework = INFERENCE_RECOMMENDER_FRAMEWORK_MAPPING .get (self ._framework () , framework )
126127
127128 framework_version = self ._get_framework_version ()
128129
@@ -176,7 +177,38 @@ def right_size(
176177
177178 return self
178179
179- def _check_inference_recommender_args (
180+ def _update_params (
181+ self ,
182+ ** kwargs ,
183+ ):
184+ """Check and update params based on inference recommendation id or right size case"""
185+ instance_type = kwargs ["instance_type" ]
186+ initial_instance_count = kwargs ["initial_instance_count" ]
187+ accelerator_type = kwargs ["accelerator_type" ]
188+ async_inference_config = kwargs ["async_inference_config" ]
189+ serverless_inference_config = kwargs ["serverless_inference_config" ]
190+ inference_recommendation_id = kwargs ["inference_recommendation_id" ]
191+ inference_recommender_job_results = kwargs ["inference_recommender_job_results" ]
192+ if inference_recommendation_id is not None :
193+ inference_recommendation = self ._update_params_for_recommendation_id (
194+ instance_type = instance_type ,
195+ initial_instance_count = initial_instance_count ,
196+ accelerator_type = accelerator_type ,
197+ async_inference_config = async_inference_config ,
198+ serverless_inference_config = serverless_inference_config ,
199+ inference_recommendation_id = inference_recommendation_id ,
200+ )
201+ elif inference_recommender_job_results is not None :
202+ inference_recommendation = self ._update_params_for_right_size (
203+ instance_type ,
204+ initial_instance_count ,
205+ accelerator_type ,
206+ serverless_inference_config ,
207+ async_inference_config ,
208+ )
209+ return inference_recommendation or (instance_type , initial_instance_count )
210+
211+ def _update_params_for_right_size (
180212 self ,
181213 instance_type = None ,
182214 initial_instance_count = None ,
@@ -232,6 +264,161 @@ def _check_inference_recommender_args(
232264 ]
233265 return (instance_type , initial_instance_count )
234266
267+ def _update_params_for_recommendation_id (
268+ self ,
269+ instance_type ,
270+ initial_instance_count ,
271+ accelerator_type ,
272+ async_inference_config ,
273+ serverless_inference_config ,
274+ inference_recommendation_id ,
275+ ):
276+ """Update parameters with inference recommendation results.
277+
278+ Args:
279+ instance_type (str): The EC2 instance type to deploy this Model to.
280+ For example, 'ml.p2.xlarge', or 'local' for local mode. If not using
281+ serverless inference, then it is required to deploy a model.
282+ initial_instance_count (int): The initial number of instances to run
283+ in the ``Endpoint`` created from this ``Model``. If not using
284+ serverless inference, then it need to be a number larger or equals
285+ to 1.
286+ accelerator_type (str): Type of Elastic Inference accelerator to
287+ deploy this model for model loading and inference, for example,
288+ 'ml.eia1.medium'. If not specified, no Elastic Inference
289+ accelerator will be attached to the endpoint. For more
290+ information:
291+ https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html
292+ async_inference_config (sagemaker.model_monitor.AsyncInferenceConfig): Specifies
293+ configuration related to async endpoint. Use this configuration when trying
294+ to create async endpoint and make async inference. If empty config object
295+ passed through, will use default config to deploy async endpoint. Deploy a
296+ real-time endpoint if it's None.
297+ serverless_inference_config (sagemaker.serverless.ServerlessInferenceConfig):
298+ Specifies configuration related to serverless endpoint. Use this configuration
299+ when trying to create serverless endpoint and make serverless inference. If
300+ empty object passed through, will use pre-defined values in
301+ ``ServerlessInferenceConfig`` class to deploy serverless endpoint. Deploy an
302+ instance based endpoint if it's None.
303+ inference_recommendation_id (str): The recommendation id which specifies
304+ the recommendation you picked from inference recommendation job
305+ results and would like to deploy the model and endpoint with
306+ recommended parameters.
307+ Raises:
308+ ValueError: If arguments combination check failed in these circumstances:
309+ - If only one of instance type or instance count specified or
310+ - If recommendation id does not follow the required format or
311+ - If recommendation id is not valid or
312+ - If inference recommendation id is specified along with incompatible parameters
313+ Returns:
314+ (string, int): instance type and associated instance count from selected
315+ inference recommendation id if arguments combination check passed.
316+ """
317+
318+ if instance_type is not None and initial_instance_count is not None :
319+ LOGGER .warning (
320+ "Both instance_type and initial_instance_count are specified,"
321+ "overriding the recommendation result."
322+ )
323+ return (instance_type , initial_instance_count )
324+
325+ # Validate non-compatible parameters with recommendation id
326+ if bool (instance_type ) != bool (initial_instance_count ):
327+ raise ValueError (
328+ "Please either do not specify instance_type and initial_instance_count"
329+ "since they are in recommendation, or specify both of them if you want"
330+ "to override the recommendation."
331+ )
332+ if accelerator_type is not None :
333+ raise ValueError ("accelerator_type is not compatible with inference_recommendation_id." )
334+ if async_inference_config is not None :
335+ raise ValueError (
336+ "async_inference_config is not compatible with inference_recommendation_id."
337+ )
338+ if serverless_inference_config is not None :
339+ raise ValueError (
340+ "serverless_inference_config is not compatible with inference_recommendation_id."
341+ )
342+
343+ # Validate recommendation id
344+ if not re .match (r"[a-zA-Z0-9](-*[a-zA-Z0-9]){0,63}\/\w{8}$" , inference_recommendation_id ):
345+ raise ValueError ("Inference Recommendation id is not valid" )
346+ recommendation_job_name = inference_recommendation_id .split ("/" )[0 ]
347+
348+ sage_client = self .sagemaker_session .sagemaker_client
349+ recommendation_res = sage_client .describe_inference_recommendations_job (
350+ JobName = recommendation_job_name
351+ )
352+ input_config = recommendation_res ["InputConfig" ]
353+
354+ recommendation = next (
355+ (
356+ rec
357+ for rec in recommendation_res ["InferenceRecommendations" ]
358+ if rec ["RecommendationId" ] == inference_recommendation_id
359+ ),
360+ None ,
361+ )
362+
363+ if not recommendation :
364+ raise ValueError (
365+ "inference_recommendation_id does not exist in InferenceRecommendations list"
366+ )
367+
368+ model_config = recommendation ["ModelConfiguration" ]
369+ envs = (
370+ model_config ["EnvironmentParameters" ]
371+ if "EnvironmentParameters" in model_config
372+ else None
373+ )
374+ # Update envs
375+ recommend_envs = {}
376+ if envs is not None :
377+ for env in envs :
378+ recommend_envs [env ["Key" ]] = env ["Value" ]
379+ self .env .update (recommend_envs )
380+
381+ # Update params with non-compilation recommendation results
382+ if (
383+ "InferenceSpecificationName" not in model_config
384+ and "CompilationJobName" not in model_config
385+ ):
386+
387+ if "ModelPackageVersionArn" in input_config :
388+ modelpkg_res = sage_client .describe_model_package (
389+ ModelPackageName = input_config ["ModelPackageVersionArn" ]
390+ )
391+ self .model_data = modelpkg_res ["InferenceSpecification" ]["Containers" ][0 ][
392+ "ModelDataUrl"
393+ ]
394+ self .image_uri = modelpkg_res ["InferenceSpecification" ]["Containers" ][0 ]["Image" ]
395+ elif "ModelName" in input_config :
396+ model_res = sage_client .describe_model (ModelName = input_config ["ModelName" ])
397+ self .model_data = model_res ["PrimaryContainer" ]["ModelDataUrl" ]
398+ self .image_uri = model_res ["PrimaryContainer" ]["Image" ]
399+ else :
400+ if "InferenceSpecificationName" in model_config :
401+ modelpkg_res = sage_client .describe_model_package (
402+ ModelPackageName = input_config ["ModelPackageVersionArn" ]
403+ )
404+ self .model_data = modelpkg_res ["AdditionalInferenceSpecificationDefinition" ][
405+ "Containers"
406+ ][0 ]["ModelDataUrl" ]
407+ self .image_uri = modelpkg_res ["AdditionalInferenceSpecificationDefinition" ][
408+ "Containers"
409+ ][0 ]["Image" ]
410+ elif "CompilationJobName" in model_config :
411+ compilation_res = sage_client .describe_compilation_job (
412+ CompilationJobName = model_config ["CompilationJobName" ]
413+ )
414+ self .model_data = compilation_res ["ModelArtifacts" ]["S3ModelArtifacts" ]
415+ self .image_uri = compilation_res ["InferenceImage" ]
416+
417+ instance_type = recommendation ["EndpointConfiguration" ]["InstanceType" ]
418+ initial_instance_count = recommendation ["EndpointConfiguration" ]["InitialInstanceCount" ]
419+
420+ return (instance_type , initial_instance_count )
421+
235422 def _convert_to_endpoint_configurations_json (
236423 self , hyperparameter_ranges : List [Dict [str , CategoricalParameter ]]
237424 ):
0 commit comments