     _set_serve_properties,
     _get_admissible_tensor_parallel_degrees,
     _get_admissible_dtypes,
+    _get_default_tensor_parallel_degree,
+)
+from sagemaker.serve.utils.local_hardware import (
+    _get_nb_instance,
+    _get_ram_usage_mb,
+    _get_gpu_info,
+    _get_gpu_info_fallback,
 )
-from sagemaker.serve.utils.local_hardware import _get_nb_instance, _get_ram_usage_mb
 from sagemaker.serve.model_server.djl_serving.prepare import (
     prepare_for_djl_serving,
     _create_dir_structure,
@@ -164,13 +170,6 @@ def _create_djl_model(self) -> Type[Model]:
     @_capture_telemetry("djl.deploy")
     def _djl_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBase]:
         """Placeholder docstring"""
-        prepare_for_djl_serving(
-            model_path=self.model_path,
-            model=self.pysdk_model,
-            dependencies=self.dependencies,
-            overwrite_props_from_file=self.overwrite_props_from_file,
-        )
-
         timeout = kwargs.get("model_data_download_timeout")
         if timeout:
             self.env_vars.update({"MODEL_LOADING_TIMEOUT": str(timeout)})
@@ -192,6 +191,34 @@ def _djl_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBa
         else:
             raise ValueError("Mode %s is not supported!" % overwrite_mode)

+        manual_set_props = None
+        if self.mode == Mode.SAGEMAKER_ENDPOINT:
+            if self.nb_instance_type and "instance_type" not in kwargs:
+                kwargs.update({"instance_type": self.nb_instance_type})
+            elif not self.nb_instance_type and "instance_type" not in kwargs:
+                raise ValueError(
+                    "Instance type must be provided when deploying " "to SageMaker Endpoint mode."
+                )
+            else:
+                try:
+                    tot_gpus = _get_gpu_info(kwargs.get("instance_type"), self.sagemaker_session)
+                except Exception:  # pylint: disable=W0703
+                    tot_gpus = _get_gpu_info_fallback(kwargs.get("instance_type"))
+                default_tensor_parallel_degree = _get_default_tensor_parallel_degree(
+                    self.hf_model_config, tot_gpus
+                )
+                manual_set_props = {
+                    "option.tensor_parallel_degree": str(default_tensor_parallel_degree) + "\n"
+                }
+
+        prepare_for_djl_serving(
+            model_path=self.model_path,
+            model=self.pysdk_model,
+            dependencies=self.dependencies,
+            overwrite_props_from_file=self.overwrite_props_from_file,
+            manual_set_props=manual_set_props,
+        )
+
         serializer = self.schema_builder.input_serializer
         deserializer = self.schema_builder._output_deserializer
         if self.mode == Mode.LOCAL_CONTAINER:
@@ -237,8 +264,6 @@ def _djl_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBa

         if "endpoint_logging" not in kwargs:
             kwargs["endpoint_logging"] = True
-        if self.nb_instance_type and "instance_type" not in kwargs:
-            kwargs.update({"instance_type": self.nb_instance_type})

         predictor = self._original_deploy(*args, **kwargs)

@@ -252,6 +277,7 @@ def _build_for_hf_djl(self):
252277 """Placeholder docstring"""
253278 self .overwrite_props_from_file = True
254279 self .nb_instance_type = _get_nb_instance ()
280+
255281 _create_dir_structure (self .model_path )
256282 self .engine , self .hf_model_config = _auto_detect_engine (
257283 self .model , self .env_vars .get ("HUGGING_FACE_HUB_TOKEN" )
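
For reference, the new SAGEMAKER_ENDPOINT branch above looks up the GPU count for the target instance type (with _get_gpu_info_fallback as a backstop when the lookup fails) and derives a default tensor parallel degree that prepare_for_djl_serving receives through manual_set_props, presumably to be written into serving.properties as "option.tensor_parallel_degree". The snippet below is only a rough sketch of that selection step; pick_tensor_parallel_degree and its divisor-based rule are illustrative assumptions, not the SDK's actual _get_default_tensor_parallel_degree.

# Illustrative sketch only -- assumes the default degree is the largest divisor of the
# model's attention-head count that fits on the available GPUs. The real helper in
# sagemaker.serve may apply different rules.
from typing import Optional


def pick_tensor_parallel_degree(num_attention_heads: int, total_gpus: int) -> Optional[int]:
    """Return the largest degree <= total_gpus that evenly divides num_attention_heads."""
    candidates = [
        d
        for d in range(1, num_attention_heads + 1)
        if num_attention_heads % d == 0 and d <= total_gpus
    ]
    return max(candidates) if candidates else None


# Example: a 32-head model on a 4-GPU instance (e.g. ml.g5.12xlarge) yields degree 4,
# which would then surface in serving.properties as "option.tensor_parallel_degree=4".
print(pick_tensor_parallel_degree(num_attention_heads=32, total_gpus=4))  # 4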