88from torch .utils .data import DataLoader
99
1010from mmengine .evaluator import Evaluator
11- from mmengine .logging import print_log
11+ from mmengine .logging import HistoryBuffer , print_log
1212from mmengine .registry import LOOPS
13+ from mmengine .structures import BaseDataElement
14+ from mmengine .utils import is_list_of
1315from .amp import autocast
1416from .base_loop import BaseLoop
1517from .utils import calc_dynamic_intervals
@@ -363,17 +365,26 @@ def __init__(self,
363365 logger = 'current' ,
364366 level = logging .WARNING )
365367 self .fp16 = fp16
368+ self .val_loss : Dict [str , HistoryBuffer ] = dict ()
366369
def run(self) -> dict:
    """Launch validation.

    Fires the ``before_val`` / ``before_val_epoch`` hooks, switches the
    model to eval mode, resets the recorded validation losses and feeds
    every batch of ``self.dataloader`` to ``self.run_iter``. The metrics
    computed by ``self.evaluator`` are then extended with the averaged
    losses (if any were recorded) before the ``after_val_epoch`` and
    ``after_val`` hooks are fired.

    Returns:
        dict: The metrics returned by ``self.evaluator.evaluate``, updated
        with the parsed validation losses when ``self.val_loss`` is not
        empty.
    """
    self.runner.call_hook('before_val')
    self.runner.call_hook('before_val_epoch')
    self.runner.model.eval()

    # Start from a clean slate so losses of a previous run are not mixed in.
    self.val_loss.clear()
    for batch_idx, batch in enumerate(self.dataloader):
        self.run_iter(batch_idx, batch)

    # The evaluator is told the dataset size when computing the metrics.
    metrics = self.evaluator.evaluate(len(self.dataloader.dataset))

    # Only report losses when at least one was recorded during the loop.
    if self.val_loss:
        metrics.update(_parse_losses(self.val_loss, 'val'))

    self.runner.call_hook('after_val_epoch', metrics=metrics)
    self.runner.call_hook('after_val')
    return metrics
@@ -391,6 +402,9 @@ def run_iter(self, idx, data_batch: Sequence[dict]):
391402 # outputs should be sequence of BaseDataElement
392403 with autocast (enabled = self .fp16 ):
393404 outputs = self .runner .model .val_step (data_batch )
405+
406+ outputs , self .val_loss = _update_losses (outputs , self .val_loss )
407+
394408 self .evaluator .process (data_samples = outputs , data_batch = data_batch )
395409 self .runner .call_hook (
396410 'after_val_iter' ,
@@ -435,17 +449,26 @@ def __init__(self,
435449 logger = 'current' ,
436450 level = logging .WARNING )
437451 self .fp16 = fp16
452+ self .test_loss : Dict [str , HistoryBuffer ] = dict ()
438453
def run(self) -> dict:
    """Launch test.

    Fires the ``before_test`` / ``before_test_epoch`` hooks, switches the
    model to eval mode, resets the recorded test losses and feeds every
    batch of ``self.dataloader`` to ``self.run_iter``. The metrics computed
    by ``self.evaluator`` are then extended with the averaged losses (if
    any were recorded) before the ``after_test_epoch`` and ``after_test``
    hooks are fired.

    Returns:
        dict: The metrics returned by ``self.evaluator.evaluate``, updated
        with the parsed test losses when ``self.test_loss`` is not empty.
    """
    self.runner.call_hook('before_test')
    self.runner.call_hook('before_test_epoch')
    self.runner.model.eval()

    # Start from a clean slate so losses of a previous run are not mixed in.
    self.test_loss.clear()
    for batch_idx, batch in enumerate(self.dataloader):
        self.run_iter(batch_idx, batch)

    # The evaluator is told the dataset size when computing the metrics.
    metrics = self.evaluator.evaluate(len(self.dataloader.dataset))

    # Only report losses when at least one was recorded during the loop.
    if self.test_loss:
        metrics.update(_parse_losses(self.test_loss, 'test'))

    self.runner.call_hook('after_test_epoch', metrics=metrics)
    self.runner.call_hook('after_test')
    return metrics
@@ -462,9 +485,66 @@ def run_iter(self, idx, data_batch: Sequence[dict]) -> None:
462485 # predictions should be sequence of BaseDataElement
463486 with autocast (enabled = self .fp16 ):
464487 outputs = self .runner .model .test_step (data_batch )
488+
489+ outputs , self .test_loss = _update_losses (outputs , self .test_loss )
490+
465491 self .evaluator .process (data_samples = outputs , data_batch = data_batch )
466492 self .runner .call_hook (
467493 'after_test_iter' ,
468494 batch_idx = idx ,
469495 data_batch = data_batch ,
470496 outputs = outputs )
497+
498+
499+ def _parse_losses (losses : Dict [str , HistoryBuffer ],
500+ stage : str ) -> Dict [str , float ]:
501+ """Parses the raw losses of the network.
502+
503+ Args:
504+ losses (dict): raw losses of the network.
505+ stage (str): The stage of loss, e.g., 'val' or 'test'.
506+
507+ Returns:
508+ dict[str, float]: The key is the loss name, and the value is the
509+ average loss.
510+ """
511+ all_loss = 0
512+ loss_dict : Dict [str , float ] = dict ()
513+
514+ for loss_name , loss_value in losses .items ():
515+ avg_loss = loss_value .mean ()
516+ loss_dict [loss_name ] = avg_loss
517+ if 'loss' in loss_name :
518+ all_loss += avg_loss
519+
520+ loss_dict [f'{ stage } _loss' ] = all_loss
521+ return loss_dict
522+
523+
524+ def _update_losses (outputs : list , losses : dict ) -> Tuple [list , dict ]:
525+ """Update and record the losses of the network.
526+
527+ Args:
528+ outputs (list): The outputs of the network.
529+ losses (dict): The losses of the network.
530+
531+ Returns:
532+ list: The updated outputs of the network.
533+ dict: The updated losses of the network.
534+ """
535+ if isinstance (outputs [- 1 ],
536+ BaseDataElement ) and outputs [- 1 ].keys () == ['loss' ]:
537+ loss = outputs [- 1 ].loss # type: ignore
538+ outputs = outputs [:- 1 ]
539+ else :
540+ loss = dict ()
541+
542+ for loss_name , loss_value in loss .items ():
543+ if loss_name not in losses :
544+ losses [loss_name ] = HistoryBuffer ()
545+ if isinstance (loss_value , torch .Tensor ):
546+ losses [loss_name ].update (loss_value .item ())
547+ elif is_list_of (loss_value , torch .Tensor ):
548+ for loss_value_i in loss_value :
549+ losses [loss_name ].update (loss_value_i .item ())
550+ return outputs , losses
0 commit comments