Skip to content

Commit 085e0df

Browse files
Fix a bug related to MPS
1 parent d0e82b3 commit 085e0df

File tree

1 file changed

+4
-8
lines changed

1 file changed

+4
-8
lines changed

ignite/metrics/vision/object_detection_average_precision_recall.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -104,19 +104,13 @@ def box_iou(pred_boxes: torch.Tensor, gt_boxes: torch.Tensor, iscrowd: torch.Boo
104104
except ImportError:
105105
raise ModuleNotFoundError("This metric requires torchvision to be installed.")
106106

107-
precision = torch.double if torch.device(device) != torch.device("mps") else torch.float32
108-
109107
if iou_thresholds is None:
110-
iou_thresholds = torch.linspace(0.5, 0.95, 10, device=device, dtype=precision)
108+
iou_thresholds = torch.linspace(0.5, 0.95, 10, dtype=torch.double)
111109

112110
self._iou_thresholds = self._setup_thresholds(iou_thresholds, "iou_thresholds")
113-
self._iou_thresholds = self._iou_thresholds.to(device=device, dtype=precision)
114111

115112
if rec_thresholds is None:
116-
rec_thresholds = torch.linspace(0, 1, 101, device=device, dtype=precision)
117-
118-
self._rec_thresholds = self._setup_thresholds(rec_thresholds, "rec_thresholds")
119-
self._rec_thresholds = self._rec_thresholds.to(device=device, dtype=precision)
113+
rec_thresholds = torch.linspace(0, 1, 101, dtype=torch.double)
120114

121115
self._num_classes = num_classes
122116
self._area_range = area_range
@@ -130,6 +124,8 @@ def box_iou(pred_boxes: torch.Tensor, gt_boxes: torch.Tensor, iscrowd: torch.Boo
130124
rec_thresholds=rec_thresholds,
131125
class_mean=None,
132126
)
127+
precision = torch.double if torch.device(device) != torch.device("mps") else torch.float32
128+
self.rec_thresholds = self.rec_thresholds.to(device=device, dtype=precision)
133129

134130
@reinit__is_reduced
135131
def reset(self) -> None:

0 commit comments

Comments
 (0)