Skip to content

Commit 74f989b

Browse files
HAOCHENYEMGAMZ
authored andcommitted
[Test] Fix unittest of EMAHook
1 parent e8a92b8 commit 74f989b

File tree

1 file changed

+10
-5
lines changed

1 file changed

+10
-5
lines changed

tests/test_hooks/test_ema_hook.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,8 @@ def test_with_runner(self):
230230
self.assertTrue(
231231
isinstance(ema_hook.ema_model, ExponentialMovingAverage))
232232

233-
checkpoint = torch.load(osp.join(self.temp_dir.name, 'epoch_2.pth'))
233+
checkpoint = torch.load(osp.join(self.temp_dir.name, 'epoch_2.pth'),
234+
weights_only=False)
234235
self.assertTrue('ema_state_dict' in checkpoint)
235236
self.assertTrue(checkpoint['ema_state_dict']['steps'] == 8)
236237

@@ -245,7 +246,8 @@ def test_with_runner(self):
245246
runner.test()
246247

247248
# Test load checkpoint without ema_state_dict
248-
checkpoint = torch.load(osp.join(self.temp_dir.name, 'epoch_2.pth'))
249+
checkpoint = torch.load(osp.join(self.temp_dir.name, 'epoch_2.pth'),
250+
weights_only=False)
249251
checkpoint.pop('ema_state_dict')
250252
torch.save(checkpoint,
251253
osp.join(self.temp_dir.name, 'without_ema_state_dict.pth'))
@@ -274,7 +276,8 @@ def test_with_runner(self):
274276
runner = self.build_runner(cfg)
275277
runner.train()
276278
state_dict = torch.load(osp.join(self.temp_dir.name, 'epoch_4.pth'),
277-
map_location='cpu')
279+
map_location='cpu',
280+
weights_only=False)
278281
self.assertIn('ema_state_dict', state_dict)
279282
for k, v in state_dict['state_dict'].items():
280283
assert_allclose(v, state_dict['ema_state_dict']['module.' + k])
@@ -287,12 +290,14 @@ def test_with_runner(self):
287290
runner = self.build_runner(cfg)
288291
runner.train()
289292
state_dict = torch.load(osp.join(self.temp_dir.name, 'iter_4.pth'),
290-
map_location='cpu')
293+
map_location='cpu',
294+
weights_only=False)
291295
self.assertIn('ema_state_dict', state_dict)
292296
for k, v in state_dict['state_dict'].items():
293297
assert_allclose(v, state_dict['ema_state_dict']['module.' + k])
294298
state_dict = torch.load(osp.join(self.temp_dir.name, 'iter_5.pth'),
295-
map_location='cpu')
299+
map_location='cpu',
300+
weights_only=False)
296301
self.assertIn('ema_state_dict', state_dict)
297302

298303
def _test_swap_parameters(self, func_name, *args, **kwargs):

0 commit comments

Comments
 (0)