1616
1717import sagemaker
1818
19- from sagemaker . local import LocalSession
20- from sagemaker . fw_utils import tar_and_upload_dir , parse_s3_url , model_code_key_prefix
21- from sagemaker . session import Session
22- from sagemaker . utils import name_from_image , get_config_value
19+ from sagemaker import local
20+ from sagemaker import fw_utils
21+ from sagemaker import session
22+ from sagemaker import utils
2323
2424
2525class Model (object ):
@@ -96,12 +96,12 @@ def deploy(self, initial_instance_count, instance_type, endpoint_name=None, tags
9696 """
9797 if not self .sagemaker_session :
9898 if instance_type in ('local' , 'local_gpu' ):
99- self .sagemaker_session = LocalSession ()
99+ self .sagemaker_session = local . LocalSession ()
100100 else :
101- self .sagemaker_session = Session ()
101+ self .sagemaker_session = session . Session ()
102102
103103 container_def = self .prepare_container_def (instance_type )
104- self .name = self .name or name_from_image (container_def ['Image' ])
104+ self .name = self .name or utils . name_from_image (container_def ['Image' ])
105105 self .sagemaker_session .create_model (self .name , self .role , container_def , vpc_config = self .vpc_config )
106106 production_variant = sagemaker .production_variant (self .name , instance_type , initial_instance_count )
107107 self .endpoint_name = endpoint_name or self .name
@@ -127,7 +127,7 @@ class FrameworkModel(Model):
127127
128128 def __init__ (self , model_data , image , role , entry_point , source_dir = None , predictor_cls = None , env = None , name = None ,
129129 enable_cloudwatch_metrics = False , container_log_level = logging .INFO , code_location = None ,
130- sagemaker_session = None , ** kwargs ):
130+ sagemaker_session = None , dependencies = None , ** kwargs ):
131131 """Initialize a ``FrameworkModel``.
132132
133133 Args:
@@ -140,6 +140,23 @@ def __init__(self, model_data, image, role, entry_point, source_dir=None, predic
140140 source code dependencies aside from tne entry point file (default: None). Structure within this
141141 directory will be preserved when training on SageMaker.
142142 If the directory points to S3, no code will be uploaded and the S3 location will be used instead.
143+ dependencies (list[str]): A list of paths to directories (absolute or relative) with
144+ any additional libraries that will be exported to the container (default: []).
145+ The library folders will be copied to SageMaker in the same folder where the entrypoint is copied.
146+ If the ```source_dir``` points to S3, code will be uploaded and the S3 location will be used
147+ instead. Example:
148+
149+ The following call
150+ >>> Estimator(entry_point='train.py', dependencies=['my/libs/common', 'virtual-env'])
151+ results in the following inside the container:
152+
153+ >>> $ ls
154+
155+ >>> opt/ml/code
156+ >>> ├── train.py
157+ >>> ├── common
158+ >>> └── virtual-env
159+
143160 predictor_cls (callable[string, sagemaker.session.Session]): A function to call to create
144161 a predictor (default: None). If not None, ``deploy`` will return the result of invoking
145162 this function on the created endpoint name.
@@ -160,10 +177,11 @@ def __init__(self, model_data, image, role, entry_point, source_dir=None, predic
160177 sagemaker_session = sagemaker_session , ** kwargs )
161178 self .entry_point = entry_point
162179 self .source_dir = source_dir
180+ self .dependencies = dependencies or []
163181 self .enable_cloudwatch_metrics = enable_cloudwatch_metrics
164182 self .container_log_level = container_log_level
165183 if code_location :
166- self .bucket , self .key_prefix = parse_s3_url (code_location )
184+ self .bucket , self .key_prefix = fw_utils . parse_s3_url (code_location )
167185 else :
168186 self .bucket , self .key_prefix = None , None
169187 self .uploaded_code = None
@@ -179,22 +197,24 @@ def prepare_container_def(self, instance_type): # pylint disable=unused-argumen
179197 Returns:
180198 dict[str, str]: A container definition object usable with the CreateModel API.
181199 """
182- deploy_key_prefix = model_code_key_prefix (self .key_prefix , self .name , self .image )
200+ deploy_key_prefix = fw_utils . model_code_key_prefix (self .key_prefix , self .name , self .image )
183201 self ._upload_code (deploy_key_prefix )
184202 deploy_env = dict (self .env )
185203 deploy_env .update (self ._framework_env_vars ())
186204 return sagemaker .container_def (self .image , self .model_data , deploy_env )
187205
188206 def _upload_code (self , key_prefix ):
189- local_code = get_config_value ('local.local_code' , self .sagemaker_session .config )
207+ local_code = utils . get_config_value ('local.local_code' , self .sagemaker_session .config )
190208 if self .sagemaker_session .local_mode and local_code :
191209 self .uploaded_code = None
192210 else :
193- self .uploaded_code = tar_and_upload_dir (session = self .sagemaker_session .boto_session ,
194- bucket = self .bucket or self .sagemaker_session .default_bucket (),
195- s3_key_prefix = key_prefix ,
196- script = self .entry_point ,
197- directory = self .source_dir )
211+ bucket = self .bucket or self .sagemaker_session .default_bucket ()
212+ self .uploaded_code = fw_utils .tar_and_upload_dir (session = self .sagemaker_session .boto_session ,
213+ bucket = bucket ,
214+ s3_key_prefix = key_prefix ,
215+ script = self .entry_point ,
216+ directory = self .source_dir ,
217+ dependencies = self .dependencies )
198218
199219 def _framework_env_vars (self ):
200220 if self .uploaded_code :
0 commit comments