@@ -53,6 +53,7 @@ def __init__(
5353 self .smooth = smooth
5454 self .eps = eps
5555 self .log_loss = log_loss
56+ self .ignore_index = ignore_index
5657
5758 def forward (self , y_pred : torch .Tensor , y_true : torch .Tensor ) -> torch .Tensor :
5859
@@ -75,17 +76,34 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
7576 y_true = y_true .view (bs , 1 , - 1 )
7677 y_pred = y_pred .view (bs , 1 , - 1 )
7778
79+ if self .ignore_index is not None :
80+ mask = y_true != self .ignore_index
81+ y_pred = y_pred * mask
82+ y_true = y_true * mask
83+
7884 if self .mode == MULTICLASS_MODE :
7985 y_true = y_true .view (bs , - 1 )
8086 y_pred = y_pred .view (bs , num_classes , - 1 )
8187
82- y_true = F .one_hot (y_true , num_classes ) # N,H*W -> N,H*W, C
83- y_true = y_true .permute (0 , 2 , 1 ) # H, C, H*W
88+ if self .ignore_index is not None :
89+ mask = y_true != self .ignore_index
90+ y_pred = y_pred * mask .unsqueeze (1 )
91+
92+ y_true = F .one_hot ((y_true * mask ).to (torch .long ), num_classes ) # N,H*W -> N,H*W, C
93+ y_true = y_true .permute (0 , 2 , 1 ) * mask .unsqueeze (1 ) # H, C, H*W
94+ else :
95+ y_true = F .one_hot (y_true , num_classes ) # N,H*W -> N,H*W, C
96+ y_true = y_true .permute (0 , 2 , 1 ) # H, C, H*W
8497
8598 if self .mode == MULTILABEL_MODE :
8699 y_true = y_true .view (bs , num_classes , - 1 )
87100 y_pred = y_pred .view (bs , num_classes , - 1 )
88101
102+ if self .ignore_index is not None :
103+ mask = y_true != self .ignore_index
104+ y_pred = y_pred * mask
105+ y_true = y_true * mask
106+
89107 scores = soft_dice_score (y_pred , y_true .type_as (y_pred ), smooth = self .smooth , eps = self .eps , dims = dims )
90108
91109 if self .log_loss :
@@ -104,4 +122,4 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
104122 if self .classes is not None :
105123 loss = loss [self .classes ]
106124
107- return loss .mean ()
125+ return loss .mean ()
0 commit comments