File tree Expand file tree Collapse file tree 7 files changed +21
-7
lines changed Expand file tree Collapse file tree 7 files changed +21
-7
lines changed Original file line number Diff line number Diff line change @@ -235,7 +235,7 @@ def __init__(
235235 self .factors_init_sigma = factors_init_sigma
236236 self .factors_init_value = factors_init_value
237237
238- def create_model (self , vpc_config_override = VPC_CONFIG_DEFAULT ):
238+ def create_model (self , vpc_config_override = VPC_CONFIG_DEFAULT , ** kwargs ):
239239 """Return a :class:`~sagemaker.amazon.FactorizationMachinesModel`
240240 referencing the latest s3 model data produced by this Estimator.
241241
@@ -244,12 +244,14 @@ def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT):
244244 the model. Default: use subnets and security groups from this Estimator.
245245 * 'Subnets' (list[str]): List of subnet ids.
246246 * 'SecurityGroupIds' (list[str]): List of security group ids.
247+ **kwargs: Additional kwargs passed to the FactorizationMachinesModel constructor.
247248 """
248249 return FactorizationMachinesModel (
249250 self .model_data ,
250251 self .role ,
251252 sagemaker_session = self .sagemaker_session ,
252253 vpc_config = self .get_vpc_config (vpc_config_override ),
254+ ** kwargs
253255 )
254256
255257
Original file line number Diff line number Diff line change @@ -131,7 +131,7 @@ def __init__(
131131 self .shuffled_negative_sampling_rate = shuffled_negative_sampling_rate
132132 self .weight_decay = weight_decay
133133
134- def create_model (self , vpc_config_override = VPC_CONFIG_DEFAULT ):
134+ def create_model (self , vpc_config_override = VPC_CONFIG_DEFAULT , ** kwargs ):
135135 """Create a model for the latest s3 model produced by this estimator.
136136
137137 Args:
@@ -140,6 +140,7 @@ def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT):
140140 Default: use subnets and security groups from this Estimator.
141141 * 'Subnets' (list[str]): List of subnet ids.
142142 * 'SecurityGroupIds' (list[str]): List of security group ids.
143+ **kwargs: Additional kwargs passed to the IPInsightsModel constructor.
143144 Returns:
144145 :class:`~sagemaker.amazon.IPInsightsModel`: references the latest s3 model
145146 data produced by this estimator.
@@ -149,6 +150,7 @@ def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT):
149150 self .role ,
150151 sagemaker_session = self .sagemaker_session ,
151152 vpc_config = self .get_vpc_config (vpc_config_override ),
153+ ** kwargs
152154 )
153155
154156 def _prepare_for_training (self , records , mini_batch_size = None , job_name = None ):
Original file line number Diff line number Diff line change @@ -145,7 +145,7 @@ def __init__(
145145 '"dimension_reduction_target" is required when "dimension_reduction_type" is set.'
146146 )
147147
148- def create_model (self , vpc_config_override = VPC_CONFIG_DEFAULT ):
148+ def create_model (self , vpc_config_override = VPC_CONFIG_DEFAULT , ** kwargs ):
149149 """Return a :class:`~sagemaker.amazon.KNNModel` referencing the latest
150150 s3 model data produced by this Estimator.
151151
@@ -154,12 +154,14 @@ def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT):
154154 the model. Default: use subnets and security groups from this Estimator.
155155 * 'Subnets' (list[str]): List of subnet ids.
156156 * 'SecurityGroupIds' (list[str]): List of security group ids.
157+ **kwargs: Additional kwargs passed to the KNNModel constructor.
157158 """
158159 return KNNModel (
159160 self .model_data ,
160161 self .role ,
161162 sagemaker_session = self .sagemaker_session ,
162163 vpc_config = self .get_vpc_config (vpc_config_override ),
164+ ** kwargs
163165 )
164166
165167 def _prepare_for_training (self , records , mini_batch_size = None , job_name = None ):
Original file line number Diff line number Diff line change @@ -155,7 +155,7 @@ def __init__(
155155 self .weight_decay = weight_decay
156156 self .learning_rate = learning_rate
157157
158- def create_model (self , vpc_config_override = VPC_CONFIG_DEFAULT ):
158+ def create_model (self , vpc_config_override = VPC_CONFIG_DEFAULT , ** kwargs ):
159159 """Return a :class:`~sagemaker.amazon.NTMModel` referencing the latest
160160 s3 model data produced by this Estimator.
161161
@@ -164,12 +164,14 @@ def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT):
164164 the model. Default: use subnets and security groups from this Estimator.
165165 * 'Subnets' (list[str]): List of subnet ids.
166166 * 'SecurityGroupIds' (list[str]): List of security group ids.
167+ **kwargs: Additional kwargs passed to the NTMModel constructor.
167168 """
168169 return NTMModel (
169170 self .model_data ,
170171 self .role ,
171172 sagemaker_session = self .sagemaker_session ,
172173 vpc_config = self .get_vpc_config (vpc_config_override ),
174+ ** kwargs
173175 )
174176
175177 def _prepare_for_training ( # pylint: disable=signature-differs
Original file line number Diff line number Diff line change @@ -295,7 +295,7 @@ def __init__(
295295 self .enc0_freeze_pretrained_embedding = enc0_freeze_pretrained_embedding
296296 self .enc1_freeze_pretrained_embedding = enc1_freeze_pretrained_embedding
297297
298- def create_model (self , vpc_config_override = VPC_CONFIG_DEFAULT ):
298+ def create_model (self , vpc_config_override = VPC_CONFIG_DEFAULT , ** kwargs ):
299299 """Return a :class:`~sagemaker.amazon.Object2VecModel` referencing the
300300 latest s3 model data produced by this Estimator.
301301
@@ -304,12 +304,14 @@ def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT):
304304 the model. Default: use subnets and security groups from this Estimator.
305305 * 'Subnets' (list[str]): List of subnet ids.
306306 * 'SecurityGroupIds' (list[str]): List of security group ids.
307+ **kwargs: Additional kwargs passed to the Object2VecModel constructor.
307308 """
308309 return Object2VecModel (
309310 self .model_data ,
310311 self .role ,
311312 sagemaker_session = self .sagemaker_session ,
312313 vpc_config = self .get_vpc_config (vpc_config_override ),
314+ ** kwargs
313315 )
314316
315317 def _prepare_for_training (self , records , mini_batch_size = None , job_name = None ):
Original file line number Diff line number Diff line change @@ -120,7 +120,7 @@ def __init__(
120120 self .subtract_mean = subtract_mean
121121 self .extra_components = extra_components
122122
123- def create_model (self , vpc_config_override = VPC_CONFIG_DEFAULT ):
123+ def create_model (self , vpc_config_override = VPC_CONFIG_DEFAULT , ** kwargs ):
124124 """Return a :class:`~sagemaker.amazon.pca.PCAModel` referencing the
125125 latest s3 model data produced by this Estimator.
126126
@@ -129,12 +129,14 @@ def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT):
129129 the model. Default: use subnets and security groups from this Estimator.
130130 * 'Subnets' (list[str]): List of subnet ids.
131131 * 'SecurityGroupIds' (list[str]): List of security group ids.
132+ **kwargs: Additional kwargs passed to the PCAModel constructor.
132133 """
133134 return PCAModel (
134135 self .model_data ,
135136 self .role ,
136137 sagemaker_session = self .sagemaker_session ,
137138 vpc_config = self .get_vpc_config (vpc_config_override ),
139+ ** kwargs
138140 )
139141
140142 def _prepare_for_training (self , records , mini_batch_size = None , job_name = None ):
Original file line number Diff line number Diff line change @@ -113,7 +113,7 @@ def __init__(
113113 self .num_trees = num_trees
114114 self .eval_metrics = eval_metrics
115115
116- def create_model (self , vpc_config_override = VPC_CONFIG_DEFAULT ):
116+ def create_model (self , vpc_config_override = VPC_CONFIG_DEFAULT , ** kwargs ):
117117 """Return a :class:`~sagemaker.amazon.RandomCutForestModel` referencing
118118 the latest s3 model data produced by this Estimator.
119119
@@ -122,12 +122,14 @@ def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT):
122122 the model. Default: use subnets and security groups from this Estimator.
123123 * 'Subnets' (list[str]): List of subnet ids.
124124 * 'SecurityGroupIds' (list[str]): List of security group ids.
125+ **kwargs: Additional kwargs passed to the RandomCutForestModel constructor.
125126 """
126127 return RandomCutForestModel (
127128 self .model_data ,
128129 self .role ,
129130 sagemaker_session = self .sagemaker_session ,
130131 vpc_config = self .get_vpc_config (vpc_config_override ),
132+ ** kwargs
131133 )
132134
133135 def _prepare_for_training (self , records , mini_batch_size = None , job_name = None ):
You can’t perform that action at this time.
0 commit comments