@@ -407,35 +407,86 @@ def rotation_matrix(a: TensorType[3], b: TensorType[3]) -> TensorType[3, 3]:
     return torch.eye(3) + skew_sym_mat + skew_sym_mat @ skew_sym_mat * ((1 - c) / (s**2 + 1e-8))


+def focus_of_attention(poses: TensorType["num_poses":..., 4, 4], initial_focus: TensorType[3]) -> TensorType[3]:
+    """Compute the focus of attention of a set of cameras. Only cameras
+    that have the focus of attention in front of them are considered.
+
+    Args:
+        poses: The poses to orient.
+        initial_focus: The 3D point used to decide which cameras are initially active.
+
+    Returns:
+        The 3D position of the focus of attention.
+    """
+    # References to the same method in third-party code:
+    # https://github.com/google-research/multinerf/blob/1c8b1c552133cdb2de1c1f3c871b2813f6662265/internal/camera_utils.py#L145
+    # https://github.com/bmild/nerf/blob/18b8aebda6700ed659cb27a0c348b737a5f6ab60/load_llff.py#L197
+    active_directions = -poses[:, :3, 2:3]
+    active_origins = poses[:, :3, 3:4]
+    # initial value for testing if the focus_pt is in front or behind
+    focus_pt = initial_focus
+    # Prune cameras which currently have the focus_pt behind them.
+    active = torch.sum(active_directions.squeeze(-1) * (focus_pt - active_origins.squeeze(-1)), dim=-1) > 0
+    done = False
+    # We need at least two active cameras, else fall back to the previous solution.
+    # This may be the "poses" solution if no cameras are active on the first iteration,
+    # e.g. if they are in an outward-looking configuration.
+    while torch.sum(active.int()) > 1 and not done:
+        active_directions = active_directions[active]
+        active_origins = active_origins[active]
+        # https://en.wikipedia.org/wiki/Line–line_intersection#In_more_than_two_dimensions
+        m = torch.eye(3) - active_directions * torch.transpose(active_directions, -2, -1)
+        mt_m = torch.transpose(m, -2, -1) @ m
+        focus_pt = torch.linalg.inv(mt_m.mean(0)) @ (mt_m @ active_origins).mean(0)[:, 0]
+        active = torch.sum(active_directions.squeeze(-1) * (focus_pt - active_origins.squeeze(-1)), dim=-1) > 0
+        if active.all():
+            # the set of active cameras did not change, so we're done.
+            done = True
+    return focus_pt
+
+
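Aside (not part of the diff): the solve inside the loop is the standard least-squares closest point to a set of rays. Below is a minimal self-contained sketch of that single step; the helper name `closest_point_to_rays` and the circle of test cameras are invented for illustration.

import math
import torch

def closest_point_to_rays(origins, directions):
    # Hypothetical helper mirroring the normal-equation solve in the loop:
    # solve (mean_i M_i^T M_i) p = mean_i (M_i^T M_i o_i), where
    # M_i = I - d_i d_i^T projects onto the plane orthogonal to ray i.
    d = directions / directions.norm(dim=-1, keepdim=True)  # [n, 3]
    m = torch.eye(3) - d[..., None] * d[..., None, :]       # [n, 3, 3] projectors
    mt_m = torch.transpose(m, -2, -1) @ m
    rhs = (mt_m @ origins[..., None]).mean(0)[:, 0]         # [3]
    return torch.linalg.inv(mt_m.mean(0)) @ rhs

# Seven cameras on a circle, all aimed at the same made-up target point.
target = torch.tensor([0.5, -0.2, 1.0])
angles = torch.linspace(0, 2 * math.pi, 8)[:-1]
origins = torch.stack([3 * torch.cos(angles), 3 * torch.sin(angles), torch.zeros_like(angles)], dim=-1)
print(closest_point_to_rays(origins, target - origins))  # ~ tensor([0.5, -0.2, 1.0])

Because every test ray passes exactly through the target, the least-squares point recovers it; the pruning loop above only re-runs this step on the cameras that keep the current estimate in front of them.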
 def auto_orient_and_center_poses(
-    poses: TensorType["num_poses":..., 4, 4], method: Literal["pca", "up", "none"] = "up", center_poses: bool = True
+    poses: TensorType["num_poses":..., 4, 4], method: Literal["pca", "up", "vertical", "none"] = "vertical",
+    center_method: Literal["poses", "focus", "none"] = "poses",
 ) -> TensorType["num_poses":..., 3, 4]:
-    """Orients and centers the poses. We provide two methods for orientation: pca and up.
+    """Orients and centers the poses. We provide three methods for orientation: pca, up, and vertical.

-    pca: Orient the poses so that the principal component of the points is aligned with the axes.
-        This method works well when all of the cameras are in the same plane.
+    pca: Orient the poses so that the principal directions of the camera centers are aligned
+        with the axes, Z corresponding to the smallest principal component.
+        This method works well when all of the cameras are in the same plane, for example when
+        images are taken using a mobile robot.
     up: Orient the poses so that the average up vector is aligned with the z axis.
         This method works well when images are not at arbitrary angles.
+    vertical: Orient the poses so that the Z 3D direction projects close to the
+        y axis in images. This method works better if cameras are not all
+        looking in the same 3D direction, which may happen in camera arrays or in LLFF.

+    There are three centering methods:
+    poses: The poses are centered around the origin.
+    focus: The origin is set to the focus of attention of all cameras (the
+        closest point to the cameras' optical axes). Recommended for inward-looking
+        camera configurations.
+    none: No centering is applied.

     Args:
         poses: The poses to orient.
         method: The method to use for orientation.
-        center_poses: If True, the poses are centered around the origin.
+        center_method: The method to use to center the poses.
425472
426473 Returns:
427474 The oriented poses.
428475 """

-    translation = poses[..., :3, 3]
+    origins = poses[..., :3, 3]

-    mean_translation = torch.mean(translation, dim=0)
-    translation_diff = translation - mean_translation
+    mean_origin = torch.mean(origins, dim=0)
+    translation_diff = origins - mean_origin

-    if center_poses:
-        translation = mean_translation
+    if center_method == "poses":
+        translation = mean_origin
+    elif center_method == "focus":
+        translation = focus_of_attention(poses, mean_origin)
+    elif center_method == "none":
+        translation = torch.zeros_like(mean_origin)
     else:
-        translation = torch.zeros_like(mean_translation)
+        raise ValueError(f"Unknown value for center_method: {center_method}")

     if method == "pca":
         _, eigvec = torch.linalg.eigh(translation_diff.T @ translation_diff)
@@ -449,9 +500,41 @@ def auto_orient_and_center_poses(

         if oriented_poses.mean(axis=0)[2, 1] < 0:
             oriented_poses[:, 1:3] = -1 * oriented_poses[:, 1:3]
-    elif method == "up":
+    elif method in ("up", "vertical"):
         up = torch.mean(poses[:, :3, 1], dim=0)
         up = up / torch.linalg.norm(up)
+        if method == "vertical":
+            # If cameras are not all parallel (e.g. not in an LLFF configuration),
+            # we can find the 3D direction that most projects vertically in all
+            # cameras by minimizing ||Xu|| s.t. ||u||=1. This total least squares
+            # problem is solved by SVD.
+            x_axis_matrix = poses[:, :3, 0]
+            _, S, Vh = torch.linalg.svd(x_axis_matrix, full_matrices=False)
+            # Singular values are S_i=||Xv_i|| for each right singular vector v_i.
+            # ||S|| = sqrt(n) because the rows of X are all unit vectors and the v_i
+            # form an orthonormal basis.
+            # ||Xv_i|| = sqrt(sum(dot(x_axis_j, v_i)^2)), thus S_i/sqrt(n) is the
+            # RMS of the cosines between the x axes and v_i. If the second smallest
+            # singular value corresponds to an angle error less than 10° (cos(80°)=0.17),
+            # this is probably a degenerate camera configuration (typical values
+            # are around 5° average error for the true vertical). In this case,
+            # rather than taking the vector corresponding to the smallest singular
+            # value, we project the "up" vector onto the plane spanned by the two
+            # best singular vectors. We could also just fall back to the "up"
+            # solution.
+            if S[1] > 0.17 * math.sqrt(poses.shape[0]):
+                # regular non-degenerate configuration
+                up_vertical = Vh[2, :]
+                # It may be pointing up or down. Use "up" to disambiguate the sign.
+                up = up_vertical if torch.dot(up_vertical, up) > 0 else -up_vertical
+            else:
+                # Degenerate configuration: project "up" onto the plane spanned by
+                # the last two right singular vectors (which are orthogonal to the
+                # first). v_0 is a unit vector, no need to divide by its norm when
+                # projecting.
+                up = up - Vh[0, :] * torch.dot(up, Vh[0, :])
+                # re-normalize
+                up = up / torch.linalg.norm(up)

         rotation = rotation_matrix(up, torch.Tensor([0, 0, 1]))
         transform = torch.cat([rotation, rotation @ -translation[..., None]], dim=-1)
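Aside (not part of the diff): the total-least-squares argument in the comments above is easy to check numerically. In this invented setup the camera x axes lie almost in the world XY plane, so the right singular vector with the smallest singular value should come out as roughly ±Z.

import math
import torch

torch.manual_seed(0)
n = 32
yaw = torch.rand(n) * 2 * math.pi
# Unit camera x axes almost in the world XY plane, with small out-of-plane noise.
x_axes = torch.stack([torch.cos(yaw), torch.sin(yaw), 0.02 * torch.randn(n)], dim=-1)
x_axes = x_axes / x_axes.norm(dim=-1, keepdim=True)  # rows of X are unit vectors

_, S, Vh = torch.linalg.svd(x_axes, full_matrices=False)
print(S / math.sqrt(n))  # RMS cosines: two sizeable entries, the last one ~0.02
print(Vh[2, :])          # ~ ±[0, 0, 1]; the diff disambiguates the sign with "up"

Here S[1]/sqrt(n) is large, so this synthetic configuration takes the non-degenerate branch; shrinking the spread of yaw angles toward zero drives S[1] down and triggers the projection fallback instead.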
@@ -461,5 +544,7 @@ def auto_orient_and_center_poses(
         transform[:3, 3] = -translation
         transform = transform[:3, :]
         oriented_poses = transform @ poses
+    else:
+        raise ValueError(f"Unknown value for method: {method}")

     return oriented_poses, transform
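Aside (not part of the diff): a call site migrating from the old `center_poses=True` flag might look like the sketch below. The module name `camera_utils` and the identity placeholder poses are assumptions for illustration; only the function signature comes from this diff.

import torch
from camera_utils import auto_orient_and_center_poses  # assumed module name

poses = torch.eye(4).repeat(10, 1, 1)  # placeholder [num_poses, 4, 4] camera-to-worlds
oriented_poses, transform = auto_orient_and_center_poses(
    poses, method="vertical", center_method="focus"  # docstring recommends "focus" for inward-looking captures
)
print(oriented_poses.shape, transform.shape)  # torch.Size([10, 3, 4]) torch.Size([3, 4])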