diff --git a/evaluate.py b/evaluate.py
index 9a4e3ba2b5..d4c1f2826a 100644
--- a/evaluate.py
+++ b/evaluate.py
@@ -1,7 +1,6 @@
 import torch
 import torch.nn.functional as F
 from tqdm import tqdm
-
 from utils.dice_score import multiclass_dice_coeff, dice_coeff
 
 
@@ -9,7 +8,7 @@
 def evaluate(net, dataloader, device, amp):
     net.eval()
     num_val_batches = len(dataloader)
-    dice_score = 0
+    dice_score = 0.0
 
     # iterate over the validation set
     with torch.autocast(device.type if device.type != 'mps' else 'cpu', enabled=amp):
@@ -24,17 +23,17 @@
             mask_pred = net(image)
 
             if net.n_classes == 1:
+                # binary segmentation
                 assert mask_true.min() >= 0 and mask_true.max() <= 1, 'True mask indices should be in [0, 1]'
-                mask_pred = (F.sigmoid(mask_pred) > 0.5).float()
-                # compute the Dice score
+                mask_pred = (torch.sigmoid(mask_pred) > 0.5).float()
                 dice_score += dice_coeff(mask_pred, mask_true, reduce_batch_first=False)
             else:
+                # multi-class segmentation: compute Dice directly on class indices
                 assert mask_true.min() >= 0 and mask_true.max() < net.n_classes, 'True mask indices should be in [0, n_classes['
-                # convert to one-hot format
-                mask_true = F.one_hot(mask_true, net.n_classes).permute(0, 3, 1, 2).float()
-                mask_pred = F.one_hot(mask_pred.argmax(dim=1), net.n_classes).permute(0, 3, 1, 2).float()
-                # compute the Dice score, ignoring background
-                dice_score += multiclass_dice_coeff(mask_pred[:, 1:], mask_true[:, 1:], reduce_batch_first=False)
+                mask_pred_indices = mask_pred.argmax(dim=1)
+                # ignore background class (0) if desired
+                dice_score += multiclass_dice_coeff(mask_pred_indices, mask_true, ignore_index=0)
 
     net.train()
     return dice_score / max(num_val_batches, 1)
+