diff --git a/pose/losses/jointsmseloss.py b/pose/losses/jointsmseloss.py index ade1183..0eeb9de 100644 --- a/pose/losses/jointsmseloss.py +++ b/pose/losses/jointsmseloss.py @@ -15,7 +15,7 @@ class JointsMSELoss(nn.Module): def __init__(self, use_target_weight=True): super(JointsMSELoss, self).__init__() - self.criterion = nn.MSELoss(reduction='mean') + self.criterion = nn.MSELoss(reduction='elementwise_mean') self.use_target_weight = use_target_weight def forward(self, output, target, target_weight):