diff --git a/pytorch_forecasting/metrics/point.py b/pytorch_forecasting/metrics/point.py index f44ff1aa4..4ad731bf4 100644 --- a/pytorch_forecasting/metrics/point.py +++ b/pytorch_forecasting/metrics/point.py @@ -223,7 +223,7 @@ def update( # weight samples if weight is not None: - losses = losses * weight.unsqueeze(-1) + losses = losses * weight self._update_losses_and_lengths(losses, lengths)