@@ -135,6 +135,8 @@ class ModelTrainer(BaseModel):
135135 The SageMakerCore session. For convinience, can be imported like:
136136 `from sagemaker.modules import Session`.
137137 If not specified, a new session will be created.
138+ If the default bucket for the artifacts needs to be updated, it can be done by
139+ passing it in the Session object.
138140 role (Optional(str)):
139141 The IAM role ARN for the training job.
140142 If not specified, the default SageMaker execution role will be used.
@@ -173,7 +175,8 @@ class ModelTrainer(BaseModel):
173175 output_data_config (Optional[OutputDataConfig]):
174176 The output data configuration. This is used to specify the output data location
175177 for the training job.
176- If not specified, will default to `s3://<default_bucket>/<base_job_name>/output/`.
178+ If not specified in the session, will default to
179+ `s3://<default_bucket>/<default_prefix>/<base_job_name>/`.
177180 input_data_config (Optional[List[Union[Channel, InputData]]]):
178181 The input data config for the training job.
179182 Takes a list of Channel or InputData objects. An InputDataSource can be an S3 URI
@@ -348,7 +351,7 @@ def _populate_intelligent_defaults_from_model_trainer_space(self):
348351 configurable_attribute
349352 )(
350353 ** default_config # pylint: disable=E1134
351- ) # noqa
354+ )
352355 setattr (self , configurable_attribute , default_config )
353356
354357 def __del__ (self ):
@@ -461,7 +464,8 @@ def model_post_init(self, __context: Any):
461464 session = self .sagemaker_session
462465 base_job_name = self .base_job_name
463466 self .output_data_config = OutputDataConfig (
464- s3_output_path = f"s3://{ session .default_bucket ()} /{ base_job_name } " ,
467+ s3_output_path = f"s3://{ self ._fetch_bucket_name_and_prefix (session )} "
468+ f"/{ base_job_name } " ,
465469 compression_type = "GZIP" ,
466470 kms_key_id = None ,
467471 )
@@ -473,6 +477,12 @@ def model_post_init(self, __context: Any):
473477 if self .training_image :
474478 logger .info (f"Training image URI: { self .training_image } " )
475479
480+ def _fetch_bucket_name_and_prefix (self , session : Session ) -> str :
481+ """Helper function to get the bucket name with the corresponding prefix if applicable"""
482+ if session .default_bucket_prefix is not None :
483+ return f"{ session .default_bucket ()} /{ session .default_bucket_prefix } "
484+ return session .default_bucket ()
485+
476486 @_telemetry_emitter (feature = Feature .MODEL_TRAINER , func_name = "model_trainer.train" )
477487 @validate_call
478488 def train (
@@ -497,12 +507,16 @@ def train(
497507 Defaults to True.
498508 """
499509 self ._populate_intelligent_defaults ()
510+ current_training_job_name = _get_unique_name (self .base_job_name )
511+ input_data_key_prefix = f"{ self .base_job_name } /{ current_training_job_name } /input"
500512 if input_data_config :
501513 self .input_data_config = input_data_config
502514
503515 input_data_config = []
504516 if self .input_data_config :
505- input_data_config = self ._get_input_data_config (self .input_data_config )
517+ input_data_config = self ._get_input_data_config (
518+ self .input_data_config , input_data_key_prefix
519+ )
506520
507521 string_hyper_parameters = {}
508522 if self .hyperparameters :
@@ -524,7 +538,9 @@ def train(
524538 # The source code will be mounted at /opt/ml/input/data/sm_code in the container
525539 if self .source_code .source_dir :
526540 source_code_channel = self .create_input_data_channel (
527- SM_CODE , self .source_code .source_dir
541+ channel_name = SM_CODE ,
542+ data_source = self .source_code .source_dir ,
543+ key_prefix = input_data_key_prefix ,
528544 )
529545 input_data_config .append (source_code_channel )
530546
@@ -542,7 +558,11 @@ def train(
542558 self ._write_distributed_json (tmp_dir = drivers_dir , distributed = self .distributed )
543559
544560 # Create an input channel for drivers packaged by the sdk
545- sm_drivers_channel = self .create_input_data_channel (SM_DRIVERS , drivers_dir .name )
561+ sm_drivers_channel = self .create_input_data_channel (
562+ channel_name = SM_DRIVERS ,
563+ data_source = drivers_dir .name ,
564+ key_prefix = input_data_key_prefix ,
565+ )
546566 input_data_config .append (sm_drivers_channel )
547567
548568 # If source_code is provided, we will always use
@@ -567,7 +587,7 @@ def train(
567587
568588 if self .training_mode == Mode .SAGEMAKER_TRAINING_JOB :
569589 training_job = TrainingJob .create (
570- training_job_name = _get_unique_name ( self . base_job_name ) ,
590+ training_job_name = current_training_job_name ,
571591 algorithm_specification = algorithm_specification ,
572592 hyper_parameters = string_hyper_parameters ,
573593 input_data_config = input_data_config ,
@@ -621,14 +641,22 @@ def train(
621641 )
622642 local_container .train (wait )
623643
624- def create_input_data_channel (self , channel_name : str , data_source : DataSourceType ) -> Channel :
644+ def create_input_data_channel (
645+ self , channel_name : str , data_source : DataSourceType , key_prefix : Optional [str ] = None
646+ ) -> Channel :
625647 """Create an input data channel for the training job.
626648
627649 Args:
628650 channel_name (str): The name of the input data channel.
629651 data_source (DataSourceType): The data source for the input data channel.
630652 DataSourceType can be an S3 URI string, local file path string,
631653 S3DataSource object, or FileSystemDataSource object.
654+ key_prefix (Optional[str]): The key prefix to use when uploading data to S3.
655+ Only applicable when data_source is a local file path string.
656+ If not specified, local data will be uploaded to:
657+ s3://<default_bucket_path>/<base_job_name>/input/<channel_name>/
658+ If specified, local data will be uploaded to:
659+ s3://<default_bucket_path>/<key_prefix>/<channel_name>/
632660 """
633661 channel = None
634662 if isinstance (data_source , str ):
@@ -644,6 +672,10 @@ def create_input_data_channel(self, channel_name: str, data_source: DataSourceTy
644672 ),
645673 input_mode = "File" ,
646674 )
675+ if key_prefix :
676+ logger .warning (
677+ "key_prefix is only applicable when data_source is a local file path."
678+ )
647679 elif _is_valid_path (data_source ):
648680 if self .training_mode == Mode .LOCAL_CONTAINER :
649681 channel = Channel (
@@ -657,10 +689,17 @@ def create_input_data_channel(self, channel_name: str, data_source: DataSourceTy
657689 input_mode = "File" ,
658690 )
659691 else :
692+ key_prefix = (
693+ f"{ key_prefix } /{ channel_name } "
694+ if key_prefix
695+ else f"{ self .base_job_name } /input/{ channel_name } "
696+ )
697+ if self .sagemaker_session .default_bucket_prefix :
698+ key_prefix = f"{ self .sagemaker_session .default_bucket_prefix } /{ key_prefix } "
660699 s3_uri = self .sagemaker_session .upload_data (
661700 path = data_source ,
662701 bucket = self .sagemaker_session .default_bucket (),
663- key_prefix = f" { self . base_job_name } /input/ { channel_name } " ,
702+ key_prefix = key_prefix ,
664703 )
665704 channel = Channel (
666705 channel_name = channel_name ,
@@ -687,7 +726,9 @@ def create_input_data_channel(self, channel_name: str, data_source: DataSourceTy
687726 return channel
688727
689728 def _get_input_data_config (
690- self , input_data_channels : Optional [List [Union [Channel , InputData ]]]
729+ self ,
730+ input_data_channels : Optional [List [Union [Channel , InputData ]]],
731+ key_prefix : Optional [str ] = None ,
691732 ) -> List [Channel ]:
692733 """Get the input data configuration for the training job.
693734
@@ -706,7 +747,7 @@ def _get_input_data_config(
706747 channels .append (input_data )
707748 elif isinstance (input_data , InputData ):
708749 channel = self .create_input_data_channel (
709- input_data .channel_name , input_data .data_source
750+ input_data .channel_name , input_data .data_source , key_prefix = key_prefix
710751 )
711752 channels .append (channel )
712753 else :
@@ -850,7 +891,7 @@ def from_recipe(
850891 An array of key-value pairs. You can use tags to categorize your AWS resources
851892 in different ways, for example, by purpose, owner, or environment.
852893 sagemaker_session (Optional[Session]):
853- The SageMaker session.
894+ The SageMakerCore session.
854895 If not specified, a new session will be created.
855896 role (Optional[str]):
856897 The IAM role ARN for the training job.
0 commit comments