@@ -237,7 +237,12 @@ def _update_params(
237237 async_inference_config ,
238238 explainer_config ,
239239 )
240- return inference_recommendation or (instance_type , initial_instance_count )
240+
241+ return (
242+ inference_recommendation
243+ if inference_recommendation
244+ else (instance_type , initial_instance_count )
245+ )
241246
242247 def _update_params_for_right_size (
243248 self ,
@@ -365,12 +370,6 @@ def _update_params_for_recommendation_id(
365370 return (instance_type , initial_instance_count )
366371
367372 # Validate non-compatible parameters with recommendation id
368- if bool (instance_type ) != bool (initial_instance_count ):
369- raise ValueError (
370- "Please either do not specify instance_type and initial_instance_count"
371- "since they are in recommendation, or specify both of them if you want"
372- "to override the recommendation."
373- )
374373 if accelerator_type is not None :
375374 raise ValueError ("accelerator_type is not compatible with inference_recommendation_id." )
376375 if async_inference_config is not None :
@@ -386,30 +385,38 @@ def _update_params_for_recommendation_id(
386385
387386 # Validate recommendation id
388387 if not re .match (r"[a-zA-Z0-9](-*[a-zA-Z0-9]){0,63}\/\w{8}$" , inference_recommendation_id ):
389- raise ValueError ("Inference Recommendation id is not valid" )
390- recommendation_job_name = inference_recommendation_id .split ("/" )[0 ]
388+ raise ValueError ("inference_recommendation_id is not valid" )
389+ job_or_model_name = inference_recommendation_id .split ("/" )[0 ]
391390
392391 sage_client = self .sagemaker_session .sagemaker_client
393- recommendation_res = sage_client .describe_inference_recommendations_job (
394- JobName = recommendation_job_name
392+ # Get recommendation from right size job and model
393+ (
394+ right_size_recommendation ,
395+ model_recommendation ,
396+ right_size_job_res ,
397+ ) = self ._get_recommendation (
398+ sage_client = sage_client ,
399+ job_or_model_name = job_or_model_name ,
400+ inference_recommendation_id = inference_recommendation_id ,
395401 )
396- input_config = recommendation_res ["InputConfig" ]
397402
398- recommendation = next (
399- (
400- rec
401- for rec in recommendation_res ["InferenceRecommendations" ]
402- if rec ["RecommendationId" ] == inference_recommendation_id
403- ),
404- None ,
405- )
403+ # Update params based on model recommendation
404+ if model_recommendation :
405+ if initial_instance_count is None :
406+ raise ValueError ("Must specify model recommendation id and instance count." )
407+ self .env .update (model_recommendation ["Environment" ])
408+ instance_type = model_recommendation ["InstanceType" ]
409+ return (instance_type , initial_instance_count )
406410
407- if not recommendation :
411+ # Update params based on default inference recommendation
412+ if bool (instance_type ) != bool (initial_instance_count ):
408413 raise ValueError (
409- "inference_recommendation_id does not exist in InferenceRecommendations list"
414+ "instance_type and initial_instance_count are mutually exclusive with"
415+ "recommendation id since they are in recommendation."
416+ "Please specify both of them if you want to override the recommendation."
410417 )
411-
412- model_config = recommendation ["ModelConfiguration" ]
418+ input_config = right_size_job_res [ "InputConfig" ]
419+ model_config = right_size_recommendation ["ModelConfiguration" ]
413420 envs = (
414421 model_config ["EnvironmentParameters" ]
415422 if "EnvironmentParameters" in model_config
@@ -458,8 +465,10 @@ def _update_params_for_recommendation_id(
458465 self .model_data = compilation_res ["ModelArtifacts" ]["S3ModelArtifacts" ]
459466 self .image_uri = compilation_res ["InferenceImage" ]
460467
461- instance_type = recommendation ["EndpointConfiguration" ]["InstanceType" ]
462- initial_instance_count = recommendation ["EndpointConfiguration" ]["InitialInstanceCount" ]
468+ instance_type = right_size_recommendation ["EndpointConfiguration" ]["InstanceType" ]
469+ initial_instance_count = right_size_recommendation ["EndpointConfiguration" ][
470+ "InitialInstanceCount"
471+ ]
463472
464473 return (instance_type , initial_instance_count )
465474
@@ -527,3 +536,77 @@ def _convert_to_stopping_conditions_json(
527536 threshold .to_json for threshold in model_latency_thresholds
528537 ]
529538 return stopping_conditions
539+
540+ def _get_recommendation (self , sage_client , job_or_model_name , inference_recommendation_id ):
541+ """Get recommendation from right size job and model"""
542+ right_size_recommendation , model_recommendation , right_size_job_res = None , None , None
543+ right_size_recommendation , right_size_job_res = self ._get_right_size_recommendation (
544+ sage_client = sage_client ,
545+ job_or_model_name = job_or_model_name ,
546+ inference_recommendation_id = inference_recommendation_id ,
547+ )
548+ if right_size_recommendation is None :
549+ model_recommendation = self ._get_model_recommendation (
550+ sage_client = sage_client ,
551+ job_or_model_name = job_or_model_name ,
552+ inference_recommendation_id = inference_recommendation_id ,
553+ )
554+ if model_recommendation is None :
555+ raise ValueError ("inference_recommendation_id is not valid" )
556+
557+ return right_size_recommendation , model_recommendation , right_size_job_res
558+
559+ def _get_right_size_recommendation (
560+ self ,
561+ sage_client ,
562+ job_or_model_name ,
563+ inference_recommendation_id ,
564+ ):
565+ """Get recommendation from right size job"""
566+ right_size_recommendation , right_size_job_res = None , None
567+ try :
568+ right_size_job_res = sage_client .describe_inference_recommendations_job (
569+ JobName = job_or_model_name
570+ )
571+ if right_size_job_res :
572+ right_size_recommendation = self ._search_recommendation (
573+ recommendation_list = right_size_job_res ["InferenceRecommendations" ],
574+ inference_recommendation_id = inference_recommendation_id ,
575+ )
576+ except sage_client .exceptions .ResourceNotFound :
577+ pass
578+
579+ return right_size_recommendation , right_size_job_res
580+
581+ def _get_model_recommendation (
582+ self ,
583+ sage_client ,
584+ job_or_model_name ,
585+ inference_recommendation_id ,
586+ ):
587+ """Get recommendation from model"""
588+ model_recommendation = None
589+ try :
590+ model_res = sage_client .describe_model (ModelName = job_or_model_name )
591+ if model_res :
592+ model_recommendation = self ._search_recommendation (
593+ recommendation_list = model_res ["DeploymentRecommendation" ][
594+ "RealTimeInferenceRecommendations"
595+ ],
596+ inference_recommendation_id = inference_recommendation_id ,
597+ )
598+ except sage_client .exceptions .ResourceNotFound :
599+ pass
600+
601+ return model_recommendation
602+
603+ def _search_recommendation (self , recommendation_list , inference_recommendation_id ):
604+ """Search recommendation based on recommendation id"""
605+ return next (
606+ (
607+ rec
608+ for rec in recommendation_list
609+ if rec ["RecommendationId" ] == inference_recommendation_id
610+ ),
611+ None ,
612+ )
0 commit comments