Skip to content

Commit d1f1aab

Browse files
authored
[Feature] Support calculating loss during validation (#1503)
1 parent 66fb81f commit d1f1aab

File tree

1 file changed

+81
-1
lines changed

1 file changed

+81
-1
lines changed

mmengine/runner/loops.py

Lines changed: 81 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,10 @@
88
from torch.utils.data import DataLoader
99

1010
from mmengine.evaluator import Evaluator
11-
from mmengine.logging import print_log
11+
from mmengine.logging import HistoryBuffer, print_log
1212
from mmengine.registry import LOOPS
13+
from mmengine.structures import BaseDataElement
14+
from mmengine.utils import is_list_of
1315
from .amp import autocast
1416
from .base_loop import BaseLoop
1517
from .utils import calc_dynamic_intervals
@@ -363,17 +365,26 @@ def __init__(self,
363365
logger='current',
364366
level=logging.WARNING)
365367
self.fp16 = fp16
368+
self.val_loss: Dict[str, HistoryBuffer] = dict()
366369

367370
def run(self) -> dict:
    """Launch validation.

    Returns:
        dict: Evaluation metrics produced by the evaluator, extended with
        the averaged validation losses (via ``_parse_losses``) when the
        model reported any loss during ``val_step``.
    """
    self.runner.call_hook('before_val')
    self.runner.call_hook('before_val_epoch')
    self.runner.model.eval()

    # Clear losses accumulated by a previous validation run so the
    # averages reflect only the current epoch.
    self.val_loss.clear()
    for idx, data_batch in enumerate(self.dataloader):
        self.run_iter(idx, data_batch)

    # compute metrics
    metrics = self.evaluator.evaluate(len(self.dataloader.dataset))

    # Only report losses when the model actually produced any; models
    # that do not compute loss at val time leave `val_loss` empty.
    if self.val_loss:
        loss_dict = _parse_losses(self.val_loss, 'val')
        metrics.update(loss_dict)

    self.runner.call_hook('after_val_epoch', metrics=metrics)
    self.runner.call_hook('after_val')
    return metrics
@@ -391,6 +402,9 @@ def run_iter(self, idx, data_batch: Sequence[dict]):
391402
# outputs should be sequence of BaseDataElement
392403
with autocast(enabled=self.fp16):
393404
outputs = self.runner.model.val_step(data_batch)
405+
406+
outputs, self.val_loss = _update_losses(outputs, self.val_loss)
407+
394408
self.evaluator.process(data_samples=outputs, data_batch=data_batch)
395409
self.runner.call_hook(
396410
'after_val_iter',
@@ -435,17 +449,26 @@ def __init__(self,
435449
logger='current',
436450
level=logging.WARNING)
437451
self.fp16 = fp16
452+
self.test_loss: Dict[str, HistoryBuffer] = dict()
438453

439454
def run(self) -> dict:
    """Launch test.

    Returns:
        dict: Evaluation metrics produced by the evaluator, extended with
        the averaged test losses (via ``_parse_losses``) when the model
        reported any loss during ``test_step``.
    """
    self.runner.call_hook('before_test')
    self.runner.call_hook('before_test_epoch')
    self.runner.model.eval()

    # Clear losses accumulated by a previous test run so the averages
    # reflect only the current evaluation.
    self.test_loss.clear()
    for idx, data_batch in enumerate(self.dataloader):
        self.run_iter(idx, data_batch)

    # compute metrics
    metrics = self.evaluator.evaluate(len(self.dataloader.dataset))

    # Only report losses when the model actually produced any; models
    # that do not compute loss at test time leave `test_loss` empty.
    if self.test_loss:
        loss_dict = _parse_losses(self.test_loss, 'test')
        metrics.update(loss_dict)

    self.runner.call_hook('after_test_epoch', metrics=metrics)
    self.runner.call_hook('after_test')
    return metrics
@@ -462,9 +485,66 @@ def run_iter(self, idx, data_batch: Sequence[dict]) -> None:
462485
# predictions should be sequence of BaseDataElement
463486
with autocast(enabled=self.fp16):
464487
outputs = self.runner.model.test_step(data_batch)
488+
489+
outputs, self.test_loss = _update_losses(outputs, self.test_loss)
490+
465491
self.evaluator.process(data_samples=outputs, data_batch=data_batch)
466492
self.runner.call_hook(
467493
'after_test_iter',
468494
batch_idx=idx,
469495
data_batch=data_batch,
470496
outputs=outputs)
497+
498+
499+
def _parse_losses(losses: Dict[str, HistoryBuffer],
500+
stage: str) -> Dict[str, float]:
501+
"""Parses the raw losses of the network.
502+
503+
Args:
504+
losses (dict): raw losses of the network.
505+
stage (str): The stage of loss, e.g., 'val' or 'test'.
506+
507+
Returns:
508+
dict[str, float]: The key is the loss name, and the value is the
509+
average loss.
510+
"""
511+
all_loss = 0
512+
loss_dict: Dict[str, float] = dict()
513+
514+
for loss_name, loss_value in losses.items():
515+
avg_loss = loss_value.mean()
516+
loss_dict[loss_name] = avg_loss
517+
if 'loss' in loss_name:
518+
all_loss += avg_loss
519+
520+
loss_dict[f'{stage}_loss'] = all_loss
521+
return loss_dict
522+
523+
524+
def _update_losses(outputs: list, losses: dict) -> Tuple[list, dict]:
525+
"""Update and record the losses of the network.
526+
527+
Args:
528+
outputs (list): The outputs of the network.
529+
losses (dict): The losses of the network.
530+
531+
Returns:
532+
list: The updated outputs of the network.
533+
dict: The updated losses of the network.
534+
"""
535+
if isinstance(outputs[-1],
536+
BaseDataElement) and outputs[-1].keys() == ['loss']:
537+
loss = outputs[-1].loss # type: ignore
538+
outputs = outputs[:-1]
539+
else:
540+
loss = dict()
541+
542+
for loss_name, loss_value in loss.items():
543+
if loss_name not in losses:
544+
losses[loss_name] = HistoryBuffer()
545+
if isinstance(loss_value, torch.Tensor):
546+
losses[loss_name].update(loss_value.item())
547+
elif is_list_of(loss_value, torch.Tensor):
548+
for loss_value_i in loss_value:
549+
losses[loss_name].update(loss_value_i.item())
550+
return outputs, losses

0 commit comments

Comments
 (0)