@@ -22,10 +22,9 @@ class MeanAveragePrecision(_BasePrecisionRecall):
     def __init__(
         self,
         rec_thresholds: Optional[Union[Sequence[float], torch.Tensor]] = None,
-        average_operand: Optional[Literal["precision", "max-precision"]] = "precision",
+        average: Optional[Literal["precision", "max-precision"]] = "precision",
         class_mean: Optional[Literal["micro", "macro", "weighted", "with_other_dims"]] = "macro",
         classification_is_multilabel: bool = False,
-        allow_multiple_recalls_at_single_threshold: bool = False,
         output_transform: Callable = lambda x: x,
         device: Union[str, torch.device] = torch.device("cpu"),
     ) -> None:
@@ -38,22 +37,23 @@ def __init__(
         Mean average precision is computed by taking the mean of this average precision over different classes
         and possibly some additional dimensions in the detection task.

-        For detection tasks user must subclass this metric and implement its :meth:`do_matching`
-        method to provide the metric with desired matching logic. Then this method is called internally in
-        :meth:`update` method on prediction-target pairs. For classification, all the binary, multiclass and
-        multilabel data are supported. In the latter case, ``classification_is_multilabel`` should be set to true.
+        For detection tasks, the user should use downstream metrics like
+        :class:`~ignite.metrics.vision.object_detection_map.ObjectDetectionMAP` or subclass this metric and implement
+        its :meth:`_do_matching` method to provide the desired matching logic. This method is then called internally
+        in the :meth:`update` method on prediction-target pairs. For classification, binary, multiclass and multilabel
+        data are all supported. In the latter case, ``classification_is_multilabel`` should be set to true.

         `mean` in the mean average precision accounts for the mean of the average precision across classes.
         ``class_mean`` determines how to take this mean. In detection tasks, it's possible to take the mean of the
         average precision in other respects as well, e.g. the IoU threshold in an object detection task. To this end, average precision
-        corresponding to each value of IoU thresholds should get measured in :meth:`do_matching`. Please refer to
-        :meth:`do_matching` for more info on this.
+        corresponding to each IoU threshold value should be measured in :meth:`_do_matching`. Please refer to
+        :meth:`_do_matching` for more info on this.

         Args:
             rec_thresholds: recall thresholds (sensitivity levels) to be considered for computing Mean Average Precision.
                 It could be a 1-dim tensor or a sequence of floats. Its values should be between 0 and 1 and don't need
                 to be sorted. If missing, thresholds are considered automatically using the data.
-            average_operand: one of the values "precision" or "max-precision". In the former case, the precision at a
+            average: one of the values "precision" or "max-precision". In the former case, the precision at a
                 recall threshold is used for that threshold:

                 .. math::
@@ -62,7 +62,7 @@ def __init__(
                 :math:`r` stands for recall thresholds and :math:`P` for precision values. :math:`r_0` is set to zero.

                 In the latter case, the maximum precision across thresholds greater than or equal to a recall threshold is
-                considered as the summation operand; In other words, the precision peek across lower or equall
+                considered as the summation operand; in other words, the precision peak across lower or equal
                 sensitivity levels is used for a recall threshold:

                 .. math::
@@ -108,11 +108,6 @@ def __init__(

             classification_is_multilabel: Used in classification task and determines if the data
                 is multilabel or not. Default False.
-            allow_multiple_recalls_at_single_threshold: When there are predictions with the same scores, it's faster to
-                consider those predictions associated with different thresholds in the course of measuring recall
-                values, but it's not logically correct since those predictions are associated with a single threshold,
-                thus outputing a single recall value. This option is added mainly due to some downstream mAP metrics
-                which allow such a thing in their computation e.g. pycocotools' mAP. Default False.
             output_transform: a callable that is used to transform the
                 :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
                 form expected by the metric. This can be useful if, for example, you have a multi-output model and
@@ -127,16 +122,14 @@ def __init__(
         else:
             self.rec_thresholds = None

-        if average_operand not in ("precision", "max-precision"):
-            raise ValueError(f"Wrong `average_operand` parameter, given {average_operand}")
-        self.average_operand = average_operand
+        if average not in ("precision", "max-precision"):
+            raise ValueError(f"Wrong `average` parameter, given {average}")
+        self.average = average

         if class_mean is not None and class_mean not in ("micro", "macro", "weighted", "with_other_dims"):
             raise ValueError(f"Wrong `class_mean` parameter, given {class_mean}")
         self.class_mean = class_mean

-        self.allow_multiple_recalls_at_single_threshold = allow_multiple_recalls_at_single_threshold
-
         super(_BasePrecisionRecall, self).__init__(
             output_transform=output_transform, is_multilabel=classification_is_multilabel, device=device
         )
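For context, a minimal usage sketch of the renamed argument. The import path and argument values below are illustrative assumptions, not taken from the diff:

    from ignite.metrics import MeanAveragePrecision  # assumed import path

    # "max-precision" integrates the precision envelope, "precision" (the default) uses raw precision values
    map_metric = MeanAveragePrecision(average="max-precision", class_mean="macro")

    # An unrecognised value now fails fast with the ValueError shown above:
    # MeanAveragePrecision(average="interpolated")  # ValueError: Wrong `average` parameter, given interpolated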
@@ -169,7 +162,7 @@ def reset(self) -> None:
         Reset method of the metric
         """
         super(_BasePrecisionRecall, self).reset()
-        if self.do_matching.__func__ == MeanAveragePrecision.do_matching:  # type: ignore[attr-defined]
+        if self._do_matching.__func__ == MeanAveragePrecision._do_matching:  # type: ignore[attr-defined]
             self._task: Literal["classification", "detection"] = "classification"
         else:
             self._task = "detection"
@@ -202,7 +195,7 @@ def _check_matching_output_shape(
     ) -> None:
         if not (tps.keys() == fps.keys() == scores.keys()):
             raise ValueError(
-                "Returned TP, FP and scores dictionaries from do_matching should have"
+                "Returned TP, FP and scores dictionaries from _do_matching should have"
                 f" the same keys (classes), given {tps.keys()}, {fps.keys()} and {scores.keys()}"
             )
         try:
@@ -228,7 +221,7 @@ def _check_matching_output_shape(
             else:
                 if self_tp_or_fp[cls][-1].shape[:-1] != new_tp_or_fp[cls].shape[:-1]:
                     raise ValueError(
-                        f"Tensors in returned {name} from do_matching should not change in shape "
+                        f"Tensors in returned {name} from _do_matching should not change in shape "
                         "except possibly in the last dimension which is the dimension of samples. Given "
                        f"{self_tp_or_fp[cls][-1].shape} and {new_tp_or_fp[cls].shape}"
                     )
@@ -278,13 +271,13 @@ def _classification_prepare_output(

         return scores, P

-    def do_matching(
+    def _do_matching(
         self, pred: Any, target: Any
     ) -> Tuple[Dict[int, torch.Tensor], Dict[int, torch.Tensor], Dict[int, int], Dict[int, torch.Tensor]]:
         r"""
         Matching logic holder of the metric for detection tasks.

-        User must implement this method by subclassing the metric. There is no constraint on type and shape of
+        The developer must implement this method by subclassing the metric. There is no constraint on the type and shape of
         ``pred`` and ``target``, but the method should return a quadruple of dictionaries containing TP, FP,
         P (actual positive) counts and scores for each class respectively. Please note that class numbers start from
         zero.
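For illustration, a hedged sketch of a subclass wiring a toy matcher into `_do_matching`. The dict-based ``pred``/``target`` layout and the label-only matching rule are assumptions; only the returned quadruple (per-class TP, FP, P counts and scores) follows the signature above:

    import torch

    class ToyDetectionMAP(MeanAveragePrecision):
        def _do_matching(self, pred, target):
            # Assumed layout: pred = {"labels": LongTensor, "scores": FloatTensor},
            # target = {"labels": LongTensor}. A real matcher would also use boxes and IoU.
            tp, fp, p, scores = {}, {}, {}, {}
            for cls in set(pred["labels"].tolist()) | set(target["labels"].tolist()):
                cls_mask = pred["labels"] == cls
                n_pred = int(cls_mask.sum())
                # Toy rule: a prediction is a TP iff its class occurs in the target at all.
                is_tp = torch.full((n_pred,), int(cls in target["labels"]), dtype=torch.uint8)
                tp[cls] = is_tp
                fp[cls] = 1 - is_tp
                p[cls] = int((target["labels"] == cls).sum())
                scores[cls] = pred["scores"][cls_mask]
            return tp, fp, p, scores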
@@ -315,7 +308,8 @@ def do_matching(
             `(TP, FP, P, scores)` A quadruple of true positives, false positives, number of actual positives and scores.
         """
         raise NotImplementedError(
-            "Please subclass MeanAveragePrecision and implement `do_matching` method" " to use the metric in detection."
+            "Please subclass MeanAveragePrecision and implement `_do_matching` method"
+            " to use the metric in detection."
         )

     @reinit__is_reduced
@@ -324,7 +318,7 @@ def update(self, output: Union[Tuple[Any, Any], Tuple[torch.Tensor, torch.Tensor

         Args:
             output: a binary tuple. It should consist of prediction and target tensors in the classification case but
-                for detection it is the same as the implemented-by-user :meth:`do_matching`.
+                for detection it is the same as the input of the user-implemented :meth:`_do_matching`.

                 For classification, this metric follows the same rules on the shape of ``output`` members as
                 :meth:`Precision.update <precision.Precision.update>` except for ``y_pred`` of binary and multilabel
@@ -341,7 +335,7 @@ def update(self, output: Union[Tuple[Any, Any], Tuple[torch.Tensor, torch.Tensor
                 P.to(self._device, dtype=torch.uint8 if self._type != "multiclass" else torch.long)
             )
         else:
-            tps, fps, ps, scores_dict = self.do_matching(output[0], output[1])
+            tps, fps, ps, scores_dict = self._do_matching(output[0], output[1])
             self._check_matching_output_shape(tps, fps, scores_dict)
             for cls in tps:
                 self._tp[cls].append(tps[cls].to(device=self._device, dtype=torch.uint8))
@@ -353,7 +347,7 @@ def update(self, output: Union[Tuple[Any, Any], Tuple[torch.Tensor, torch.Tensor
         if classes:
             self._num_classes = max(max(classes) + 1, self._num_classes)

-    def _measure_recall_and_precision(
+    def _compute_recall_and_precision(
         self, TP: torch.Tensor, FP: Union[torch.Tensor, None], scores: torch.Tensor, P: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         r"""Measuring recall & precision which is the common operation among different settings of the metric.
@@ -363,67 +357,55 @@ def _measure_recall_and_precision(
         classification task. ``...`` stands for the additional dimensions in the detection task. Finally,
         \#unique scores represents the number of unique scores in ``scores``, which is actually the number of thresholds.

-        This method is called on a per class basis in the detection task and if
-        ``allow_multiple_recalls_at_single_threshold=False``.
-
-        ===========================  ==================================  ===================================
+        ============== ======================
         Detection task
-        --------------------------------------------------------------------------------------------------
-        **Object**/ **Condition**    ``allow_multiple_recalls...=True``  ``allow_multiple_recalls...=False``
-        ===========================  ==================================  ===================================
-        TP and FP                    (..., N\ :sub:`pred`)               (..., N\ :sub:`pred`)
-        scores                       (N\ :sub:`pred`,)                   (N\ :sub:`pred`,)
-        P                            () (A single float)                 () (A single float)
-        recall                       (..., N\ :sub:`pred`)               (..., \#unique scores)
-        precision                    (..., N\ :sub:`pred`)               (..., \#unique scores)
-        ===========================  ==================================  ===================================
-
-        ===========================  ==================================  ===================================
+        -------------------------------------
+        **Object**     **Shape**
+        ============== ======================
+        TP and FP      (..., N\ :sub:`pred`)
+        scores         (N\ :sub:`pred`,)
+        P              () (A single float)
+        recall         (..., \#unique scores)
+        precision      (..., \#unique scores)
+        ============== ======================
+
+        =================== =======================================
         Classification task
-        --------------------------------------------------------------------------------------------------
-        **Object**/ **Condition**    ``allow_multiple_recalls...=True``  ``allow_multiple_recalls...=False``
-        ===========================  ==================================  ===================================
-        TP                           (C, N\ :sub:`pred`)                 (N\ :sub:`pred`,)
-        FP                           (C, N\ :sub:`pred`)                 None (FP is computed here to be
-                                                                         faster)
-        scores                       (C, N\ :sub:`pred`)                 (N\ :sub:`pred`,)
-        P                            (C,)                                () (A single float)
-        recall                       (C, N\ :sub:`pred`)                 (\#unique scores,)
-        precision                    (C, N\ :sub:`pred`)                 (\#unique scores,)
-        ===========================  ==================================  ===================================
+        -----------------------------------------------------------
+        **Object**          **Shape**
+        =================== =======================================
+        TP                  (N\ :sub:`pred`,)
+        FP                  None (FP is computed here to be faster)
+        scores              (N\ :sub:`pred`,)
+        P                   () (A single float)
+        recall              (\#unique scores,)
+        precision           (\#unique scores,)
+        =================== =======================================

         Returns:
             `(recall, precision)`
         """
         indices = torch.argsort(scores, dim=-1, stable=True, descending=True)
         tp = TP.take_along_dim(indices, dim=-1) if self._task == "classification" else TP[..., indices]
         tp_summation = tp.cumsum(dim=-1).double()
-        if self._task == "detection" or self.allow_multiple_recalls_at_single_threshold:
-            fp = (
-                cast(torch.Tensor, FP).take_along_dim(indices, dim=-1)
-                if self._task == "classification"
-                else cast(torch.Tensor, FP)[..., indices]
-            )
+
+        # Adopted from Scikit-learn's implementation
+        unique_scores_indices = torch.nonzero(
+            scores.take_along_dim(indices).diff(append=(scores.max() + 1).unsqueeze(dim=0)), as_tuple=True
+        )[0]
+        tp_summation = tp_summation[..., unique_scores_indices]
+        if self._task == "classification":
+            fp_summation = (unique_scores_indices + 1) - tp_summation
+        else:
+            fp = cast(torch.Tensor, FP)[..., indices]
             fp_summation = fp.cumsum(dim=-1).double()
-        if not self.allow_multiple_recalls_at_single_threshold:
-            # Adopted from Scikit-learn's implementation
-            unique_scores_indices = torch.nonzero(
-                scores.take_along_dim(indices).diff(append=(scores.max() + 1).unsqueeze(dim=0)), as_tuple=True
-            )[0]
-            tp_summation = tp_summation[..., unique_scores_indices]
-            if self._task == "classification":
-                fp_summation = (unique_scores_indices + 1) - tp_summation
-            else:
-                fp_summation = fp_summation[..., unique_scores_indices]
+            fp_summation = fp_summation[..., unique_scores_indices]

-        if self._task == "classification" and self.allow_multiple_recalls_at_single_threshold:
-            recall = torch.where(P == 0, 1, tp_summation.T / P).T
-        elif self._task == "classification" and P == 0:
+        if self._task == "classification" and P == 0:
             recall = torch.ones_like(tp_summation, device=self._device, dtype=torch.bool)
         else:
             recall = tp_summation / P
-        # precision = tp_summation / (fp_summation + tp_summation + torch.finfo(torch.double).eps)
-        # or
+
         predicted_positive = tp_summation + fp_summation
         precision = tp_summation / torch.where(predicted_positive == 0, 1, predicted_positive)
         return recall, precision
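As a standalone sketch of the single-recall-value-per-threshold behaviour in the classification branch above, with made-up scores containing a tie so that the duplicate score collapses into one threshold:

    import torch

    scores = torch.tensor([0.9, 0.8, 0.8, 0.3])                # per-sample confidences (illustrative)
    tp_flags = torch.tensor([1, 0, 1, 1], dtype=torch.uint8)   # ground-truth positives, so P = 3

    order = torch.argsort(scores, dim=-1, stable=True, descending=True)
    tp_cum = tp_flags[order].cumsum(dim=-1).double()
    # Keep only the last cumulative count of each distinct score (scikit-learn style)
    unique_idx = torch.nonzero(scores[order].diff(append=(scores.max() + 1).unsqueeze(0)), as_tuple=True)[0]
    tp_cum = tp_cum[unique_idx]
    fp_cum = (unique_idx + 1) - tp_cum                         # predictions so far minus the true positives
    recall = tp_cum / 3
    precision = tp_cum / (tp_cum + fp_cum)                     # the denominator is never zero here
    # recall    -> tensor([0.3333, 0.6667, 1.0000])  (3 thresholds, not 4, because of the tie at 0.8)
    # precision -> tensor([1.0000, 0.6667, 0.7500])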
@@ -440,7 +422,7 @@ def _measure_average_precision(self, recall: torch.Tensor, precision: torch.Tens
             average_precision: (n-1)-dimensional tensor containing the average precision for mean dimensions.
         """
         precision_integrand = (
-            precision.flip(-1).cummax(dim=-1).values.flip(-1) if self.average_operand == "max-precision" else precision
+            precision.flip(-1).cummax(dim=-1).values.flip(-1) if self.average == "max-precision" else precision
         )
         if self.rec_thresholds is not None:
             rec_thresholds = self.rec_thresholds.repeat((*recall.shape[:-1], 1))
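Continuing the same numbers, a sketch of how the two ``average`` modes change the integration when no explicit ``rec_thresholds`` are given. It applies the docstring formula, the sum over thresholds of (r_i - r_{i-1}) * P_i with r_0 = 0, and "max-precision" first replaces each precision value by its running maximum over higher recalls (the flip/cummax/flip above); it mirrors that formula rather than the full method:

    import torch

    recall = torch.tensor([1 / 3, 2 / 3, 1.0])
    precision = torch.tensor([1.0, 2 / 3, 0.75])

    delta_r = recall.diff(prepend=torch.zeros(1))                  # r_i - r_{i-1}, with r_0 = 0

    ap_precision = (delta_r * precision).sum()                     # average="precision"      -> ~0.8056
    envelope = precision.flip(-1).cummax(dim=-1).values.flip(-1)   # [1.0, 0.75, 0.75]
    ap_max_precision = (delta_r * envelope).sum()                  # average="max-precision"  -> ~0.8333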
@@ -549,7 +531,7 @@ def compute(self) -> Union[torch.Tensor, float]:
                 if TP[cls].size(-1) == 0:
                     average_precisions[cls] = 0
                     continue
-                recall, precision = self._measure_recall_and_precision(TP[cls], FP[cls], scores[cls], P[cls])
+                recall, precision = self._compute_recall_and_precision(TP[cls], FP[cls], scores[cls], P[cls])
                 average_precision_for_cls_across_other_dims = self._measure_average_precision(recall, precision)
                 if self.class_mean != "with_other_dims":
                     average_precisions[cls] = average_precision_for_cls_across_other_dims.mean()
@@ -607,7 +589,7 @@ def compute(self) -> Union[torch.Tensor, float]:
                 )
             )
             P = P.sum()
-            recall, precision = self._measure_recall_and_precision(TP_micro, FP_micro, scores_micro, P)
+            recall, precision = self._compute_recall_and_precision(TP_micro, FP_micro, scores_micro, P)
             return self._measure_average_precision(recall, precision).mean()
         else:
             rank_P = (
@@ -644,16 +626,12 @@ def compute(self) -> Union[torch.Tensor, float]:
                 P = P.reshape(1, -1)
                 scores_classification = scores_classification.view(1, -1)
             P_count = P.sum(dim=-1)
-            if self.allow_multiple_recalls_at_single_threshold:
-                recall, precision = self._measure_recall_and_precision(P, 1 - P, scores_classification, P_count)
-                average_precisions = self._measure_average_precision(recall, precision)
-            else:
-                average_precisions = torch.zeros_like(P_count, device=self._device, dtype=torch.double)
-                for cls in range(len(P_count)):
-                    recall, precision = self._measure_recall_and_precision(
-                        P[cls], None, scores_classification[cls], P_count[cls]
-                    )
-                    average_precisions[cls] = self._measure_average_precision(recall, precision)
+            average_precisions = torch.zeros_like(P_count, device=self._device, dtype=torch.double)
+            for cls in range(len(P_count)):
+                recall, precision = self._compute_recall_and_precision(
+                    P[cls], None, scores_classification[cls], P_count[cls]
+                )
+                average_precisions[cls] = self._measure_average_precision(recall, precision)
             if self._type == "binary":
                 return average_precisions.item()
             if self.class_mean is None: