@@ -230,7 +230,8 @@ def test_with_runner(self):
230230 self .assertTrue (
231231 isinstance (ema_hook .ema_model , ExponentialMovingAverage ))
232232
233- checkpoint = torch .load (osp .join (self .temp_dir .name , 'epoch_2.pth' ))
233+ checkpoint = torch .load (osp .join (self .temp_dir .name , 'epoch_2.pth' ),
234+ weights_only = False )
234235 self .assertTrue ('ema_state_dict' in checkpoint )
235236 self .assertTrue (checkpoint ['ema_state_dict' ]['steps' ] == 8 )
236237
@@ -245,7 +246,8 @@ def test_with_runner(self):
245246 runner .test ()
246247
247248 # Test load checkpoint without ema_state_dict
248- checkpoint = torch .load (osp .join (self .temp_dir .name , 'epoch_2.pth' ))
249+ checkpoint = torch .load (osp .join (self .temp_dir .name , 'epoch_2.pth' ),
250+ weights_only = False )
249251 checkpoint .pop ('ema_state_dict' )
250252 torch .save (checkpoint ,
251253 osp .join (self .temp_dir .name , 'without_ema_state_dict.pth' ))
@@ -274,7 +276,8 @@ def test_with_runner(self):
274276 runner = self .build_runner (cfg )
275277 runner .train ()
276278 state_dict = torch .load (osp .join (self .temp_dir .name , 'epoch_4.pth' ),
277- map_location = 'cpu' )
279+ map_location = 'cpu' ,
280+ weights_only = False )
278281 self .assertIn ('ema_state_dict' , state_dict )
279282 for k , v in state_dict ['state_dict' ].items ():
280283 assert_allclose (v , state_dict ['ema_state_dict' ]['module.' + k ])
@@ -287,12 +290,14 @@ def test_with_runner(self):
287290 runner = self .build_runner (cfg )
288291 runner .train ()
289292 state_dict = torch .load (osp .join (self .temp_dir .name , 'iter_4.pth' ),
290- map_location = 'cpu' )
293+ map_location = 'cpu' ,
294+ weights_only = False )
291295 self .assertIn ('ema_state_dict' , state_dict )
292296 for k , v in state_dict ['state_dict' ].items ():
293297 assert_allclose (v , state_dict ['ema_state_dict' ]['module.' + k ])
294298 state_dict = torch .load (osp .join (self .temp_dir .name , 'iter_5.pth' ),
295- map_location = 'cpu' )
299+ map_location = 'cpu' ,
300+ weights_only = False )
296301 self .assertIn ('ema_state_dict' , state_dict )
297302
298303 def _test_swap_parameters (self , func_name , * args , ** kwargs ):
0 commit comments