Skip to content

Commit ad8b15a

Browse files
committed
[Test] Fix unittest of EMAHook
1 parent 19f2489 commit ad8b15a

File tree

1 file changed

+13
-5
lines changed

1 file changed

+13
-5
lines changed

tests/test_hooks/test_ema_hook.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,8 @@ def test_with_runner(self):
230230
self.assertTrue(
231231
isinstance(ema_hook.ema_model, ExponentialMovingAverage))
232232

233-
checkpoint = torch.load(osp.join(self.temp_dir.name, 'epoch_2.pth'))
233+
checkpoint = torch.load(
234+
osp.join(self.temp_dir.name, 'epoch_2.pth'), weights_only=False)
234235
self.assertTrue('ema_state_dict' in checkpoint)
235236
self.assertTrue(checkpoint['ema_state_dict']['steps'] == 8)
236237

@@ -245,7 +246,8 @@ def test_with_runner(self):
245246
runner.test()
246247

247248
# Test load checkpoint without ema_state_dict
248-
checkpoint = torch.load(osp.join(self.temp_dir.name, 'epoch_2.pth'))
249+
checkpoint = torch.load(
250+
osp.join(self.temp_dir.name, 'epoch_2.pth'), weights_only=False)
249251
checkpoint.pop('ema_state_dict')
250252
torch.save(checkpoint,
251253
osp.join(self.temp_dir.name, 'without_ema_state_dict.pth'))
@@ -274,7 +276,9 @@ def test_with_runner(self):
274276
runner = self.build_runner(cfg)
275277
runner.train()
276278
state_dict = torch.load(
277-
osp.join(self.temp_dir.name, 'epoch_4.pth'), map_location='cpu')
279+
osp.join(self.temp_dir.name, 'epoch_4.pth'),
280+
map_location='cpu',
281+
weights_only=False)
278282
self.assertIn('ema_state_dict', state_dict)
279283
for k, v in state_dict['state_dict'].items():
280284
assert_allclose(v, state_dict['ema_state_dict']['module.' + k])
@@ -287,12 +291,16 @@ def test_with_runner(self):
287291
runner = self.build_runner(cfg)
288292
runner.train()
289293
state_dict = torch.load(
290-
osp.join(self.temp_dir.name, 'iter_4.pth'), map_location='cpu')
294+
osp.join(self.temp_dir.name, 'iter_4.pth'),
295+
map_location='cpu',
296+
weights_only=False)
291297
self.assertIn('ema_state_dict', state_dict)
292298
for k, v in state_dict['state_dict'].items():
293299
assert_allclose(v, state_dict['ema_state_dict']['module.' + k])
294300
state_dict = torch.load(
295-
osp.join(self.temp_dir.name, 'iter_5.pth'), map_location='cpu')
301+
osp.join(self.temp_dir.name, 'iter_5.pth'),
302+
map_location='cpu',
303+
weights_only=False)
296304
self.assertIn('ema_state_dict', state_dict)
297305

298306
def _test_swap_parameters(self, func_name, *args, **kwargs):

0 commit comments

Comments
 (0)