From 9b96668f977bef1beb5b2a6610770559b3b0ccef Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E0=A4=86=E0=A4=B2=E0=A5=8B=E0=A4=95?=
Date: Thu, 25 Jun 2020 17:58:02 +0530
Subject: [PATCH] fix bug in call to BCEWithLogits

The keyword argument had a typo: 'weights' instead of 'weight'. In case
weights requires grad, it also needs to be detached first, as
binary_cross_entropy_with_logits is not implemented for a grad-tracked
weight tensor.
---
 class_balanced_loss.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/class_balanced_loss.py b/class_balanced_loss.py
index 179274d..c241f3a 100644
--- a/class_balanced_loss.py
+++ b/class_balanced_loss.py
@@ -84,7 +84,8 @@ def CB_loss(labels, logits, samples_per_cls, no_of_classes, loss_type, beta, gam
     if loss_type == "focal":
         cb_loss = focal_loss(labels_one_hot, logits, weights, gamma)
     elif loss_type == "sigmoid":
-        cb_loss = F.binary_cross_entropy_with_logits(input = logits,target = labels_one_hot, weights = weights)
+        cb_loss = F.binary_cross_entropy_with_logits(input = logits,target = labels_one_hot,
+                                                     weight=weights.detach() if weights.requires_grad else weights)
     elif loss_type == "softmax":
         pred = logits.softmax(dim = 1)
         cb_loss = F.binary_cross_entropy(input = pred, target = labels_one_hot, weight = weights)
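
For reference, a minimal sketch of the corrected call (the shapes, dummy
labels, and weights construction below are illustrative only, not taken
from class_balanced_loss.py):

    import torch
    import torch.nn.functional as F

    # illustrative setup: batch of 4 samples, 3 classes
    logits = torch.randn(4, 3, requires_grad=True)
    labels_one_hot = torch.eye(3)[torch.tensor([0, 2, 1, 0])]
    # stand-in for the class-balanced weights; marked as requiring grad
    # here only to exercise the detach branch
    weights = torch.rand(4, 3).requires_grad_()

    # 'weight' (not 'weights') is the keyword accepted by
    # binary_cross_entropy_with_logits; grad-tracked weights are detached
    # first, per the commit message above
    cb_loss = F.binary_cross_entropy_with_logits(
        input=logits, target=labels_one_hot,
        weight=weights.detach() if weights.requires_grad else weights)
    cb_loss.backward()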