2626
2727
2828@pytest .fixture ()
29- def sagemaker_session (describe_training_result = None , list_training_results = None , metric_stats_results = None ,
30- describe_tuning_result = None ):
29+ def sagemaker_session ():
30+ return create_sagemaker_session ()
31+
32+
33+ def create_sagemaker_session (describe_training_result = None , list_training_results = None , metric_stats_results = None ,
34+ describe_tuning_result = None ):
3135 boto_mock = Mock (name = 'boto_session' , region_name = REGION )
3236 sms = Mock (name = 'sagemaker_session' , boto_session = boto_mock ,
3337 boto_region_name = REGION , config = None , local_mode = False )
@@ -77,7 +81,7 @@ def mock_summary(name="job-name", value=0.9):
7781 "layers" : 137 ,
7882 },
7983 }
80- session = sagemaker_session (list_training_results = {
84+ session = create_sagemaker_session (list_training_results = {
8185 "TrainingJobSummaries" : [
8286 mock_summary (),
8387 mock_summary (),
@@ -116,7 +120,7 @@ def mock_summary(name="job-name", value=0.9):
116120
117121
118122def test_description ():
119- session = sagemaker_session (describe_tuning_result = {
123+ session = create_sagemaker_session (describe_tuning_result = {
120124 'HyperParameterTuningJobConfig' : {
121125 'ParameterRanges' : {
122126 'CategoricalParameterRanges' : [],
@@ -155,7 +159,7 @@ def test_trainer_name():
155159 'TrainingStartTime' : datetime .datetime (2018 , 5 , 16 , 1 , 2 , 3 ),
156160 'TrainingEndTime' : datetime .datetime (2018 , 5 , 16 , 5 , 6 , 7 ),
157161 }
158- session = sagemaker_session (describe_training_result )
162+ session = create_sagemaker_session (describe_training_result )
159163 trainer = TrainingJobAnalytics ("my-training-job" , ["metric" ], sagemaker_session = session )
160164 assert trainer .name == "my-training-job"
161165 assert str (trainer ).find ("my-training-job" ) != - 1
@@ -182,8 +186,8 @@ def test_trainer_dataframe():
182186 },
183187 ]
184188 }
185- session = sagemaker_session (describe_training_result = describe_training_result ,
186- metric_stats_results = metric_stats_results )
189+ session = create_sagemaker_session (describe_training_result = describe_training_result ,
190+ metric_stats_results = metric_stats_results )
187191 trainer = TrainingJobAnalytics ("my-training-job" , ["train:acc" ], sagemaker_session = session )
188192
189193 df = trainer .dataframe ()
0 commit comments