@@ -99,7 +99,9 @@ def get_stats(
9999 threshold (Optional[float, List[float]]): Binarization threshold for
100100 ``output`` in case of ``'binary'`` or ``'multilabel'`` modes. Defaults to None.
101101 num_classes (Optional[int]): Number of classes, necessary attribute
102- only for ``'multiclass'`` mode.
102+ only for ``'multiclass'`` mode. Class values should be in range 0..(num_classes - 1).
103+ If ``ignore_index`` is specified it should be outside the classes range, e.g. ``-1`` or
104+ ``255``.
103105
104106 Raises:
105107 ValueError: in case of misconfiguration.
@@ -139,12 +141,16 @@ def get_stats(
139141 if mode == "multiclass" and num_classes is None :
140142 raise ValueError ("``num_classes`` attribute should be not ``None`` for 'multiclass' mode." )
141143
144+ if ignore_index is not None and 0 <= ignore_index <= num_classes - 1 :
145+ raise ValueError (
146+ f"``ignore_index`` should be outside the class values range, but got class values in range "
147+ f"0..{ num_classes - 1 } and ``ignore_index={ ignore_index } ``. Hint: if you have ``ignore_index = 0``"
148+ f"consirder subtracting ``1`` from your target and model output to make ``ignore_index = -1``"
149+ f"and relevant class values started from ``0``."
150+ )
151+
142152 if mode == "multiclass" :
143- if ignore_index is not None :
144- ignore = target == ignore_index
145- output = torch .where (ignore , - 1 , output )
146- target = torch .where (ignore , - 1 , target )
147- tp , fp , fn , tn = _get_stats_multiclass (output , target , num_classes )
153+ tp , fp , fn , tn = _get_stats_multiclass (output , target , num_classes , ignore_index )
148154 else :
149155 if threshold is not None :
150156 output = torch .where (output >= threshold , 1 , 0 )
@@ -159,11 +165,18 @@ def _get_stats_multiclass(
159165 output : torch .LongTensor ,
160166 target : torch .LongTensor ,
161167 num_classes : int ,
168+ ignore_index : Optional [int ],
162169) -> Tuple [torch .LongTensor , torch .LongTensor , torch .LongTensor , torch .LongTensor ]:
163170
164171 batch_size , * dims = output .shape
165172 num_elements = torch .prod (torch .tensor (dims )).long ()
166173
174+ if ignore_index is not None :
175+ ignore = target == ignore_index
176+ output = torch .where (ignore , - 1 , output )
177+ target = torch .where (ignore , - 1 , target )
178+ ignore_per_sample = ignore .view (batch_size , - 1 ).sum (1 )
179+
167180 tp_count = torch .zeros (batch_size , num_classes , dtype = torch .long )
168181 fp_count = torch .zeros (batch_size , num_classes , dtype = torch .long )
169182 fn_count = torch .zeros (batch_size , num_classes , dtype = torch .long )
@@ -178,6 +191,8 @@ def _get_stats_multiclass(
178191 fp = torch .histc (output_i .float (), bins = num_classes , min = 0 , max = num_classes - 1 ) - tp
179192 fn = torch .histc (target_i .float (), bins = num_classes , min = 0 , max = num_classes - 1 ) - tp
180193 tn = num_elements - tp - fp - fn
194+ if ignore_index is not None :
195+ tn = tn - ignore_per_sample [i ]
181196 tp_count [i ] = tp .long ()
182197 fp_count [i ] = fp .long ()
183198 fn_count [i ] = fn .long ()
0 commit comments