
Commit 24fe980

Some improvements
Removed allow_multiple...
Renamed average_operand
Renamed _measure_recall... to _compute_recall...
1 parent 34a1a3f commit 24fe980

File tree

4 files changed: +173 -150 lines changed


ignite/metrics/mean_average_precision.py

Lines changed: 67 additions & 89 deletions
@@ -22,10 +22,9 @@ class MeanAveragePrecision(_BasePrecisionRecall):
     def __init__(
         self,
         rec_thresholds: Optional[Union[Sequence[float], torch.Tensor]] = None,
-        average_operand: Optional[Literal["precision", "max-precision"]] = "precision",
+        average: Optional[Literal["precision", "max-precision"]] = "precision",
         class_mean: Optional[Literal["micro", "macro", "weighted", "with_other_dims"]] = "macro",
         classification_is_multilabel: bool = False,
-        allow_multiple_recalls_at_single_threshold: bool = False,
         output_transform: Callable = lambda x: x,
         device: Union[str, torch.device] = torch.device("cpu"),
     ) -> None:
@@ -38,22 +37,23 @@ def __init__(
         Mean average precision is computed by taking the mean of this average precision over different classes
         and possibly some additional dimensions in the detection task.

-        For detection tasks user must subclass this metric and implement its :meth:`do_matching`
-        method to provide the metric with desired matching logic. Then this method is called internally in
-        :meth:`update` method on prediction-target pairs. For classification, all the binary, multiclass and
-        multilabel data are supported. In the latter case, ``classification_is_multilabel`` should be set to true.
+        For detection tasks, the user should use downstream metrics like
+        :class:`~ignite.metrics.vision.object_detection_map.ObjectDetectionMAP` or subclass this metric and implement
+        its :meth:`_do_matching` method to provide the metric with the desired matching logic. This method is then
+        called internally in the :meth:`update` method on prediction-target pairs. For classification, all of binary,
+        multiclass and multilabel data are supported. In the latter case, ``classification_is_multilabel`` should be set to true.

         `mean` in the mean average precision accounts for the mean of the average precision across classes. ``class_mean``
         determines how to take this mean. In detection tasks, it's possible to take the mean of the average precision
         in other respects as well, e.g. the IoU threshold in an object detection task. To this end, the average precision
-        corresponding to each value of IoU thresholds should get measured in :meth:`do_matching`. Please refer to
-        :meth:`do_matching` for more info on this.
+        corresponding to each value of the IoU thresholds should be measured in :meth:`_do_matching`. Please refer to
+        :meth:`_do_matching` for more info on this.

         Args:
             rec_thresholds: recall thresholds (sensitivity levels) to be considered for computing Mean Average Precision.
                 It could be a 1-dim tensor or a sequence of floats. Its values should be between 0 and 1 and don't need
                 to be sorted. If missing, thresholds are considered automatically using the data.
-            average_operand: one of values "precision" or "max-precision". In the former case, the precision at a
+            average: one of values "precision" or "max-precision". In the former case, the precision at a
                 recall threshold is used for that threshold:

                 .. math::
@@ -62,7 +62,7 @@ def __init__(
                 :math:`r` stands for recall thresholds and :math:`P` for precision values. :math:`r_0` is set to zero.

                 In the latter case, the maximum precision across thresholds greater than or equal to a recall threshold is
-                considered as the summation operand; In other words, the precision peek across lower or equall
+                considered as the summation operand; in other words, the precision peak across lower or equal
                 sensitivity levels is used for a recall threshold:

                 .. math::
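
Editor's note: the two averaging modes described above reduce to a weighted sum of precision values over recall increments, with r_0 set to zero. A minimal standalone sketch of both modes (illustration only; the recall/precision values are invented and assumed sorted by increasing recall):

import torch

recall = torch.tensor([0.2, 0.5, 0.8, 1.0])
precision = torch.tensor([1.0, 0.6, 0.7, 0.5])

# "precision": use the precision at each recall threshold directly.
dr = recall.diff(prepend=torch.tensor([0.0]))  # r_k - r_{k-1}, with r_0 = 0
ap_precision = (dr * precision).sum()

# "max-precision": use the precision peak over thresholds with recall >= r_k,
# i.e. a running maximum taken from the right (the same flip/cummax/flip trick
# used in _measure_average_precision further down the diff).
max_precision = precision.flip(-1).cummax(dim=-1).values.flip(-1)
ap_max_precision = (dr * max_precision).sum()

print(ap_precision.item(), ap_max_precision.item())  # ~0.69 vs ~0.72 for this toy curve
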
@@ -108,11 +108,6 @@ def __init__(

             classification_is_multilabel: Used in classification task and determines if the data
                 is multilabel or not. Default False.
-            allow_multiple_recalls_at_single_threshold: When there are predictions with the same scores, it's faster to
-                consider those predictions associated with different thresholds in the course of measuring recall
-                values, but it's not logically correct since those predictions are associated with a single threshold,
-                thus outputing a single recall value. This option is added mainly due to some downstream mAP metrics
-                which allow such a thing in their computation e.g. pycocotools' mAP. Default False.
             output_transform: a callable that is used to transform the
                 :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
                 form expected by the metric. This can be useful if, for example, you have a multi-output model and
@@ -127,16 +122,14 @@ def __init__(
         else:
             self.rec_thresholds = None

-        if average_operand not in ("precision", "max-precision"):
-            raise ValueError(f"Wrong `average_operand` parameter, given {average_operand}")
-        self.average_operand = average_operand
+        if average not in ("precision", "max-precision"):
+            raise ValueError(f"Wrong `average` parameter, given {average}")
+        self.average = average

         if class_mean is not None and class_mean not in ("micro", "macro", "weighted", "with_other_dims"):
             raise ValueError(f"Wrong `class_mean` parameter, given {class_mean}")
         self.class_mean = class_mean

-        self.allow_multiple_recalls_at_single_threshold = allow_multiple_recalls_at_single_threshold
-
         super(_BasePrecisionRecall, self).__init__(
             output_transform=output_transform, is_multilabel=classification_is_multilabel, device=device
         )
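
Editor's note: for reference, a hedged usage sketch of the renamed constructor argument on the classification path; the import path and the multiclass input shapes follow the Precision-style rules mentioned in the docstring and are assumptions here, not taken from this diff:

import torch
from ignite.metrics import MeanAveragePrecision  # import path assumed

metric = MeanAveragePrecision(average="max-precision", class_mean="macro")
metric.reset()

y_pred = torch.rand(8, 4)        # scores for 8 samples over 4 classes
y = torch.randint(0, 4, (8,))    # integer class targets
metric.update((y_pred, y))

print(metric.compute())          # mean of per-class average precisions
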
@@ -169,7 +162,7 @@ def reset(self) -> None:
         Reset method of the metric
         """
         super(_BasePrecisionRecall, self).reset()
-        if self.do_matching.__func__ == MeanAveragePrecision.do_matching:  # type: ignore[attr-defined]
+        if self._do_matching.__func__ == MeanAveragePrecision._do_matching:  # type: ignore[attr-defined]
             self._task: Literal["classification", "detection"] = "classification"
         else:
             self._task = "detection"
@@ -202,7 +195,7 @@ def _check_matching_output_shape(
     ) -> None:
         if not (tps.keys() == fps.keys() == scores.keys()):
             raise ValueError(
-                "Returned TP, FP and scores dictionaries from do_matching should have"
+                "Returned TP, FP and scores dictionaries from _do_matching should have"
                 f" the same keys (classes), given {tps.keys()}, {fps.keys()} and {scores.keys()}"
             )
         try:
@@ -228,7 +221,7 @@ def _check_matching_output_shape(
         else:
             if self_tp_or_fp[cls][-1].shape[:-1] != new_tp_or_fp[cls].shape[:-1]:
                 raise ValueError(
-                    f"Tensors in returned {name} from do_matching should not change in shape "
+                    f"Tensors in returned {name} from _do_matching should not change in shape "
                     "except possibly in the last dimension which is the dimension of samples. Given "
                     f"{self_tp_or_fp[cls][-1].shape} and {new_tp_or_fp[cls].shape}"
                 )
@@ -278,13 +271,13 @@ def _classification_prepare_output(

         return scores, P

-    def do_matching(
+    def _do_matching(
         self, pred: Any, target: Any
     ) -> Tuple[Dict[int, torch.Tensor], Dict[int, torch.Tensor], Dict[int, int], Dict[int, torch.Tensor]]:
         r"""
         Matching logic holder of the metric for detection tasks.

-        User must implement this method by subclassing the metric. There is no constraint on type and shape of
+        The developer must implement this method by subclassing the metric. There is no constraint on the type and shape of
         ``pred`` and ``target``, but the method should return a quadruple of dictionaries containing TP, FP,
         P (actual positive) counts and scores for each class respectively. Please note that class numbers start from
         zero.
@@ -315,7 +308,8 @@ def do_matching(
             `(TP, FP, P, scores)` A quadruple of true positives, false positives, number of actual positives and scores.
         """
         raise NotImplementedError(
-            "Please subclass MeanAveragePrecision and implement `do_matching` method" " to use the metric in detection."
+            "Please subclass MeanAveragePrecision and implement `_do_matching` method"
+            " to use the metric in detection."
         )

     @reinit__is_reduced
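
Editor's note: a minimal sketch of the subclassing pattern the docstring describes. The class name, the pred/target layout and the thresholding "matching rule" below are placeholders; only the return signature (per-class TP, FP, P counts and scores) follows the diff above:

from typing import Any, Dict, Tuple

import torch


class ToyDetectionMAP(MeanAveragePrecision):  # hypothetical subclass
    def _do_matching(
        self, pred: Any, target: Any
    ) -> Tuple[Dict[int, torch.Tensor], Dict[int, torch.Tensor], Dict[int, int], Dict[int, torch.Tensor]]:
        # Assume `pred` maps class -> 1-D tensor of confidence scores and
        # `target` maps class -> number of ground-truth objects. Real matching
        # logic (e.g. IoU-based box matching) would go here instead.
        tps: Dict[int, torch.Tensor] = {}
        fps: Dict[int, torch.Tensor] = {}
        ps: Dict[int, int] = {}
        scores: Dict[int, torch.Tensor] = {}
        for cls, cls_scores in pred.items():
            matched = cls_scores > 0.5                 # placeholder rule
            tps[cls] = matched.to(torch.uint8)         # shape (N_pred,)
            fps[cls] = (~matched).to(torch.uint8)      # shape (N_pred,)
            ps[cls] = int(target.get(cls, 0))          # actual positives for the class
            scores[cls] = cls_scores                   # shape (N_pred,)
        return tps, fps, ps, scores
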
@@ -324,7 +318,7 @@ def update(self, output: Union[Tuple[Any, Any], Tuple[torch.Tensor, torch.Tensor

         Args:
             output: a binary tuple. It should consist of prediction and target tensors in the classification case but
-                for detection it is the same as the implemented-by-user :meth:`do_matching`.
+                for detection, it is the same as the input of the user-implemented :meth:`_do_matching`.

                 For classification, this metric follows the same rules on ``output`` members shape as the
                 :meth:`Precision.update <precision.Precision.update>` except for ``y_pred`` of binary and multilabel
@@ -341,7 +335,7 @@ def update(self, output: Union[Tuple[Any, Any], Tuple[torch.Tensor, torch.Tensor
                 P.to(self._device, dtype=torch.uint8 if self._type != "multiclass" else torch.long)
             )
         else:
-            tps, fps, ps, scores_dict = self.do_matching(output[0], output[1])
+            tps, fps, ps, scores_dict = self._do_matching(output[0], output[1])
             self._check_matching_output_shape(tps, fps, scores_dict)
             for cls in tps:
                 self._tp[cls].append(tps[cls].to(device=self._device, dtype=torch.uint8))
@@ -353,7 +347,7 @@ def update(self, output: Union[Tuple[Any, Any], Tuple[torch.Tensor, torch.Tensor
         if classes:
             self._num_classes = max(max(classes) + 1, self._num_classes)

-    def _measure_recall_and_precision(
+    def _compute_recall_and_precision(
         self, TP: torch.Tensor, FP: Union[torch.Tensor, None], scores: torch.Tensor, P: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         r"""Measuring recall & precision which is the common operation among different settings of the metric.
@@ -363,67 +357,55 @@ def _measure_recall_and_precision(
         classification task. ``...`` stands for the additional dimensions in the detection task. Finally,
         \#unique scores represents number of unique scores in ``scores`` which is actually the number of thresholds.

-        This method is called on a per class basis in the detection task and if
-        ``allow_multiple_recalls_at_single_threshold=False``.
-
-        =========================== ================================== ===================================
+        ============== ======================
         Detection task
-        --------------------------------------------------------------------------------------------------
-        **Object**/ **Condition**   ``allow_multiple_recalls...=True``  ``allow_multiple_recalls...=False``
-        =========================== ================================== ===================================
-        TP and FP                   (..., N\ :sub:`pred`)               (..., N\ :sub:`pred`)
-        scores                      (N\ :sub:`pred`,)                   (N\ :sub:`pred`,)
-        P                           () (A single float)                 () (A single float)
-        recall                      (..., N\ :sub:`pred`)               (..., \#unique scores)
-        precision                   (..., N\ :sub:`pred`)               (..., \#unique scores)
-        =========================== =====================
-
-        =========================== ================================== ===================================
+        -------------------------------------
+        **Object**     **Shape**
+        ============== ======================
+        TP and FP      (..., N\ :sub:`pred`)
+        scores         (N\ :sub:`pred`,)
+        P              () (A single float)
+        recall         (..., \#unique scores)
+        precision      (..., \#unique scores)
+        ============== ======================
+
+        =================== =======================================
         Classification task
-        --------------------------------------------------------------------------------------------------
-        **Object**/ **Condition**   ``allow_multiple_recalls...=True``  ``allow_multiple_recalls...=False``
-        =========================== ================================== ===================================
-        TP                          (C, N\ :sub:`pred`)                 (N\ :sub:`pred`,)
-        FP                          (C, N\ :sub:`pred`)                 None (FP is computed here to be
-                                                                        faster)
-        scores                      (C, N\ :sub:`pred`)                 (N\ :sub:`pred`,)
-        P                           (C,)                                () (A single float)
-        recall                      (C, N\ :sub:`pred`)                 (\#unique scores,)
-        precision                   (C, N\ :sub:`pred`)                 (\#unique scores,)
-        =========================== ================================== ===================================
+        -----------------------------------------------------------
+        **Object**          **Shape**
+        =================== =======================================
+        TP                  (N\ :sub:`pred`,)
+        FP                  None (FP is computed here to be faster)
+        scores              (N\ :sub:`pred`,)
+        P                   () (A single float)
+        recall              (\#unique scores,)
+        precision           (\#unique scores,)
+        =================== =======================================

         Returns:
             `(recall, precision)`
         """
         indices = torch.argsort(scores, dim=-1, stable=True, descending=True)
         tp = TP.take_along_dim(indices, dim=-1) if self._task == "classification" else TP[..., indices]
         tp_summation = tp.cumsum(dim=-1).double()
-        if self._task == "detection" or self.allow_multiple_recalls_at_single_threshold:
-            fp = (
-                cast(torch.Tensor, FP).take_along_dim(indices, dim=-1)
-                if self._task == "classification"
-                else cast(torch.Tensor, FP)[..., indices]
-            )
+
+        # Adopted from Scikit-learn's implementation
+        unique_scores_indices = torch.nonzero(
+            scores.take_along_dim(indices).diff(append=(scores.max() + 1).unsqueeze(dim=0)), as_tuple=True
+        )[0]
+        tp_summation = tp_summation[..., unique_scores_indices]
+        if self._task == "classification":
+            fp_summation = (unique_scores_indices + 1) - tp_summation
+        else:
+            fp = cast(torch.Tensor, FP)[..., indices]
             fp_summation = fp.cumsum(dim=-1).double()
-            if not self.allow_multiple_recalls_at_single_threshold:
-                # Adopted from Scikit-learn's implementation
-                unique_scores_indices = torch.nonzero(
-                    scores.take_along_dim(indices).diff(append=(scores.max() + 1).unsqueeze(dim=0)), as_tuple=True
-                )[0]
-                tp_summation = tp_summation[..., unique_scores_indices]
-                if self._task == "classification":
-                    fp_summation = (unique_scores_indices + 1) - tp_summation
-                else:
-                    fp_summation = fp_summation[..., unique_scores_indices]
+            fp_summation = fp_summation[..., unique_scores_indices]

-        if self._task == "classification" and self.allow_multiple_recalls_at_single_threshold:
-            recall = torch.where(P == 0, 1, tp_summation.T / P).T
-        elif self._task == "classification" and P == 0:
+        if self._task == "classification" and P == 0:
             recall = torch.ones_like(tp_summation, device=self._device, dtype=torch.bool)
         else:
             recall = tp_summation / P
-        # precision = tp_summation / (fp_summation + tp_summation + torch.finfo(torch.double).eps)
-        # or
+
         predicted_positive = tp_summation + fp_summation
         precision = tp_summation / torch.where(predicted_positive == 0, 1, predicted_positive)
         return recall, precision
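
Editor's note: to make the unique-score handling above concrete, a standalone toy run of the classification branch; the values are invented for illustration and the snippet mirrors the added lines rather than calling the metric:

import torch

scores = torch.tensor([0.9, 0.8, 0.8, 0.4, 0.4, 0.1])
TP = torch.tensor([1, 0, 1, 1, 0, 0], dtype=torch.uint8)
P = TP.sum()  # actual positives, chosen for the example

indices = torch.argsort(scores, dim=-1, stable=True, descending=True)
tp_summation = TP[indices].cumsum(dim=-1).double()

# Keep one point per distinct score so tied predictions share a single threshold.
unique_scores_indices = torch.nonzero(
    scores.take_along_dim(indices).diff(append=(scores.max() + 1).unsqueeze(dim=0)), as_tuple=True
)[0]
tp_summation = tp_summation[..., unique_scores_indices]
fp_summation = (unique_scores_indices + 1) - tp_summation

recall = tp_summation / P
predicted_positive = tp_summation + fp_summation
precision = tp_summation / torch.where(predicted_positive == 0, 1, predicted_positive)

print(recall)     # tensor([0.3333, 0.6667, 1.0000, 1.0000], dtype=torch.float64)
print(precision)  # tensor([1.0000, 0.6667, 0.6000, 0.5000], dtype=torch.float64)
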
@@ -440,7 +422,7 @@ def _measure_average_precision(self, recall: torch.Tensor, precision: torch.Tens
             average_precision: (n-1)-dimensional tensor containing the average precision for mean dimensions.
         """
         precision_integrand = (
-            precision.flip(-1).cummax(dim=-1).values.flip(-1) if self.average_operand == "max-precision" else precision
+            precision.flip(-1).cummax(dim=-1).values.flip(-1) if self.average == "max-precision" else precision
         )
         if self.rec_thresholds is not None:
             rec_thresholds = self.rec_thresholds.repeat((*recall.shape[:-1], 1))
@@ -549,7 +531,7 @@ def compute(self) -> Union[torch.Tensor, float]:
                 if TP[cls].size(-1) == 0:
                     average_precisions[cls] = 0
                     continue
-                recall, precision = self._measure_recall_and_precision(TP[cls], FP[cls], scores[cls], P[cls])
+                recall, precision = self._compute_recall_and_precision(TP[cls], FP[cls], scores[cls], P[cls])
                 average_precision_for_cls_across_other_dims = self._measure_average_precision(recall, precision)
                 if self.class_mean != "with_other_dims":
                     average_precisions[cls] = average_precision_for_cls_across_other_dims.mean()
@@ -607,7 +589,7 @@ def compute(self) -> Union[torch.Tensor, float]:
                 )
             )
             P = P.sum()
-            recall, precision = self._measure_recall_and_precision(TP_micro, FP_micro, scores_micro, P)
+            recall, precision = self._compute_recall_and_precision(TP_micro, FP_micro, scores_micro, P)
             return self._measure_average_precision(recall, precision).mean()
         else:
             rank_P = (
@@ -644,16 +626,12 @@ def compute(self) -> Union[torch.Tensor, float]:
             P = P.reshape(1, -1)
             scores_classification = scores_classification.view(1, -1)
         P_count = P.sum(dim=-1)
-        if self.allow_multiple_recalls_at_single_threshold:
-            recall, precision = self._measure_recall_and_precision(P, 1 - P, scores_classification, P_count)
-            average_precisions = self._measure_average_precision(recall, precision)
-        else:
-            average_precisions = torch.zeros_like(P_count, device=self._device, dtype=torch.double)
-            for cls in range(len(P_count)):
-                recall, precision = self._measure_recall_and_precision(
-                    P[cls], None, scores_classification[cls], P_count[cls]
-                )
-                average_precisions[cls] = self._measure_average_precision(recall, precision)
+        average_precisions = torch.zeros_like(P_count, device=self._device, dtype=torch.double)
+        for cls in range(len(P_count)):
+            recall, precision = self._compute_recall_and_precision(
+                P[cls], None, scores_classification[cls], P_count[cls]
+            )
+            average_precisions[cls] = self._measure_average_precision(recall, precision)
         if self._type == "binary":
             return average_precisions.item()
         if self.class_mean is None:
