@@ -47,12 +47,49 @@ def create_sagemaker_session(describe_training_result=None, list_training_result
4747 cwm_mock = Mock (name = 'cloudwatch_client' )
4848 boto_mock .client = Mock (return_value = cwm_mock )
4949 cwm_mock .get_metric_statistics = Mock (
50- name = 'get_metric_statistics' ,
51- return_value = metric_stats_results ,
50+ name = 'get_metric_statistics'
5251 )
52+ cwm_mock .get_metric_statistics .side_effect = cw_request_side_effect
5353 return sms
5454
5555
56+ def cw_request_side_effect (Namespace , MetricName , Dimensions , StartTime , EndTime , Period , Statistics ):
57+ if _is_valid_request (Namespace , MetricName , Dimensions , StartTime , EndTime , Period , Statistics ):
58+ return _metric_stats_results ()
59+
60+
61+ def _is_valid_request (Namespace , MetricName , Dimensions , StartTime , EndTime , Period , Statistics ):
62+ could_watch_request = {
63+ 'Namespace' : Namespace ,
64+ 'MetricName' : MetricName ,
65+ 'Dimensions' : Dimensions ,
66+ 'StartTime' : StartTime ,
67+ 'EndTime' : EndTime ,
68+ 'Period' : Period ,
69+ 'Statistics' : Statistics ,
70+ }
71+ print (could_watch_request )
72+ return could_watch_request == cw_request ()
73+
74+
75+ def cw_request ():
76+ describe_training_result = _describe_training_result ()
77+ return {
78+ 'Namespace' : '/aws/sagemaker/TrainingJobs' ,
79+ 'MetricName' : 'train:acc' ,
80+ 'Dimensions' : [
81+ {
82+ 'Name' : 'TrainingJobName' ,
83+ 'Value' : 'my-training-job'
84+ }
85+ ],
86+ 'StartTime' : describe_training_result ['TrainingStartTime' ],
87+ 'EndTime' : describe_training_result ['TrainingEndTime' ] + datetime .timedelta (minutes = 1 ),
88+ 'Period' : 60 ,
89+ 'Statistics' : ['Average' ],
90+ }
91+
92+
5693def test_abstract_base_class ():
5794 # confirm that the abstract base class can't be instantiated directly
5895 with pytest .raises (TypeError ) as _ : # noqa: F841
@@ -165,12 +202,15 @@ def test_trainer_name():
165202 assert str (trainer ).find ("my-training-job" ) != - 1
166203
167204
168- def test_trainer_dataframe ():
169- describe_training_result = {
205+ def _describe_training_result ():
206+ return {
170207 'TrainingStartTime' : datetime .datetime (2018 , 5 , 16 , 1 , 2 , 3 ),
171208 'TrainingEndTime' : datetime .datetime (2018 , 5 , 16 , 5 , 6 , 7 ),
172209 }
173- metric_stats_results = {
210+
211+
212+ def _metric_stats_results ():
213+ return {
174214 'Datapoints' : [
175215 {
176216 'Average' : 77.1 ,
@@ -186,8 +226,11 @@ def test_trainer_dataframe():
186226 },
187227 ]
188228 }
189- session = create_sagemaker_session (describe_training_result = describe_training_result ,
190- metric_stats_results = metric_stats_results )
229+
230+
231+ def test_trainer_dataframe ():
232+ session = create_sagemaker_session (describe_training_result = _describe_training_result (),
233+ metric_stats_results = _metric_stats_results ())
191234 trainer = TrainingJobAnalytics ("my-training-job" , ["train:acc" ], sagemaker_session = session )
192235
193236 df = trainer .dataframe ()
0 commit comments