@@ -104,19 +104,13 @@ def box_iou(pred_boxes: torch.Tensor, gt_boxes: torch.Tensor, iscrowd: torch.Boo
104104 except ImportError :
105105 raise ModuleNotFoundError ("This metric requires torchvision to be installed." )
106106
107- precision = torch .double if torch .device (device ) != torch .device ("mps" ) else torch .float32
108-
109107 if iou_thresholds is None :
110- iou_thresholds = torch .linspace (0.5 , 0.95 , 10 , device = device , dtype = precision )
108+ iou_thresholds = torch .linspace (0.5 , 0.95 , 10 , dtype = torch . double )
111109
112110 self ._iou_thresholds = self ._setup_thresholds (iou_thresholds , "iou_thresholds" )
113- self ._iou_thresholds = self ._iou_thresholds .to (device = device , dtype = precision )
114111
115112 if rec_thresholds is None :
116- rec_thresholds = torch .linspace (0 , 1 , 101 , device = device , dtype = precision )
117-
118- self ._rec_thresholds = self ._setup_thresholds (rec_thresholds , "rec_thresholds" )
119- self ._rec_thresholds = self ._rec_thresholds .to (device = device , dtype = precision )
113+ rec_thresholds = torch .linspace (0 , 1 , 101 , dtype = torch .double )
120114
121115 self ._num_classes = num_classes
122116 self ._area_range = area_range
@@ -130,6 +124,8 @@ def box_iou(pred_boxes: torch.Tensor, gt_boxes: torch.Tensor, iscrowd: torch.Boo
130124 rec_thresholds = rec_thresholds ,
131125 class_mean = None ,
132126 )
127+ precision = torch .double if torch .device (device ) != torch .device ("mps" ) else torch .float32
128+ self .rec_thresholds = self .rec_thresholds .to (device = device , dtype = precision )
133129
134130 @reinit__is_reduced
135131 def reset (self ) -> None :
0 commit comments