@@ -451,54 +451,45 @@ def _parallelogram_to_bounding_boxes(parallelogram: torch.Tensor) -> torch.Tenso
451451 torch.Tensor: Tensor of same shape as input containing the rectangle coordinates.
452452 The output maintains the same dtype as the input.
453453 """
454+ original_shape = parallelogram .shape
454455 dtype = parallelogram .dtype
455456 acceptable_dtypes = [torch .float32 , torch .float64 ]
456457 need_cast = dtype not in acceptable_dtypes
457458 if need_cast :
458459 # Up-case to avoid overflow for square operations
459460 parallelogram = parallelogram .to (torch .float32 )
460- out_boxes = parallelogram .clone ()
461-
462- # Calculate parallelogram diagonal vectors
463- dx13 = parallelogram [..., 4 ] - parallelogram [..., 0 ]
464- dy13 = parallelogram [..., 5 ] - parallelogram [..., 1 ]
465- dx42 = parallelogram [..., 2 ] - parallelogram [..., 6 ]
466- dy42 = parallelogram [..., 3 ] - parallelogram [..., 7 ]
467- dx12 = parallelogram [..., 2 ] - parallelogram [..., 0 ]
468- dy12 = parallelogram [..., 1 ] - parallelogram [..., 3 ]
469- diag13 = torch .sqrt (dx13 ** 2 + dy13 ** 2 )
470- diag24 = torch .sqrt (dx42 ** 2 + dy42 ** 2 )
471- mask = diag13 > diag24
472-
473- # Calculate rotation angle in radians
474- r_rad = torch .atan2 (dy12 , dx12 )
475- cos , sin = torch .cos (r_rad ), torch .sin (r_rad )
476-
477- # Calculate width using the angle between diagonal and rotation
478- w = torch .where (
479- mask ,
480- diag13 * torch .abs (torch .sin (torch .atan2 (dx13 , dy13 ) - r_rad )),
481- diag24 * torch .abs (torch .sin (torch .atan2 (dx42 , dy42 ) - r_rad )),
482- )
483461
484- delta_x = w * cos
485- delta_y = w * sin
486- # Update coordinates to form a rectangle
487- # Keeping the points (x1, y1) and (x3, y3) unchanged.
488- out_boxes [..., 2 ] = torch .where (mask , parallelogram [..., 0 ] + delta_x , parallelogram [..., 2 ])
489- out_boxes [..., 3 ] = torch .where (mask , parallelogram [..., 1 ] - delta_y , parallelogram [..., 3 ])
490- out_boxes [..., 6 ] = torch .where (mask , parallelogram [..., 4 ] - delta_x , parallelogram [..., 6 ])
491- out_boxes [..., 7 ] = torch .where (mask , parallelogram [..., 5 ] + delta_y , parallelogram [..., 7 ])
492-
493- # Keeping the points (x2, y2) and (x4, y4) unchanged.
494- out_boxes [..., 0 ] = torch .where (~ mask , parallelogram [..., 2 ] - delta_x , parallelogram [..., 0 ])
495- out_boxes [..., 1 ] = torch .where (~ mask , parallelogram [..., 3 ] + delta_y , parallelogram [..., 1 ])
496- out_boxes [..., 4 ] = torch .where (~ mask , parallelogram [..., 6 ] + delta_x , parallelogram [..., 4 ])
497- out_boxes [..., 5 ] = torch .where (~ mask , parallelogram [..., 7 ] - delta_y , parallelogram [..., 5 ])
462+ x1 , y1 , x2 , y2 , x3 , y3 , x4 , y4 = parallelogram .unbind (- 1 )
463+ cx = (x1 + x3 ) / 2
464+ cy = (y1 + y3 ) / 2
465+
466+ # Calculate width, height, and rotation angle of the parallelogram
467+ wp = torch .sqrt ((x2 - x1 ) ** 2 + (y2 - y1 ) ** 2 )
468+ hp = torch .sqrt ((x4 - x1 ) ** 2 + (y4 - y1 ) ** 2 )
469+ r12 = torch .atan2 (y1 - y2 , x2 - x1 )
470+ r14 = torch .atan2 (y1 - y4 , x4 - x1 )
471+ r_rad = r12 - r14
472+ sign = torch .where (r_rad > torch .pi / 2 , - 1 , 1 )
473+ cos , sin = r_rad .cos (), r_rad .sin ()
474+
475+ # Calculate width, height, and rotation angle of the rectangle
476+ w = torch .where (wp < hp , wp * sin , wp + hp * cos * sign )
477+ h = torch .where (wp > hp , hp * sin , hp + wp * cos * sign )
478+ r_rad = torch .where (hp > wp , r14 + torch .pi / 2 , r12 )
479+ cos , sin = r_rad .cos (), r_rad .sin ()
480+
481+ x1 = cx - w / 2 * cos - h / 2 * sin
482+ y1 = cy - h / 2 * cos + w / 2 * sin
483+ x2 = cx + w / 2 * cos - h / 2 * sin
484+ y2 = cy - h / 2 * cos - w / 2 * sin
485+ x3 = cx + w / 2 * cos + h / 2 * sin
486+ y3 = cy + h / 2 * cos - w / 2 * sin
487+ x4 = cx - w / 2 * cos + h / 2 * sin
488+ y4 = cy + h / 2 * cos + w / 2 * sin
489+ out_boxes = torch .stack ((x1 , y1 , x2 , y2 , x3 , y3 , x4 , y4 ), dim = - 1 ).reshape (original_shape )
498490
499491 if need_cast :
500492 out_boxes = out_boxes .to (dtype )
501-
502493 return out_boxes
503494
504495
0 commit comments