From 5db5c254f4d145ed56cd7808a0797f8f37a52d3f Mon Sep 17 00:00:00 2001
From: Matthieu Gendrin
Date: Wed, 8 Nov 2023 17:47:09 +0100
Subject: [PATCH 1/3] Manage a principal point different from the image center

In some cases, calibration gives a principal point cx,cy != (0.5,0.5), or
the input images may be cropped. In those cases, fovx must be split into
fovXleft,fovXright and fovy into fovYtop,fovYbottom.

Note that the export of cameras to cameras.json merges those values back
into the basic fovx,fovy. This avoids having to modify the
diff_gaussian_rasterization branch used for SIBR_gaussianViewer_app.

Signed-off-by: Matthieu Gendrin
---
 gaussian_renderer/__init__.py |  4 +--
 scene/cameras.py              | 14 ++++++-----
 scene/dataset_readers.py      | 47 ++++++++++++++++++++++++-----------
 utils/camera_utils.py         |  7 +++---
 utils/graphics_utils.py       | 33 +++++++++++++++---------
 5 files changed, 67 insertions(+), 38 deletions(-)

diff --git a/gaussian_renderer/__init__.py b/gaussian_renderer/__init__.py
index f74e336af4..455dbf1c22 100644
--- a/gaussian_renderer/__init__.py
+++ b/gaussian_renderer/__init__.py
@@ -30,8 +30,8 @@ def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor,
         pass
 
     # Set up rasterization configuration
-    tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
-    tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)
+    tanfovx = math.tan((viewpoint_camera.FovXright - viewpoint_camera.FovXleft) * 0.5)
+    tanfovy = math.tan((viewpoint_camera.FovYbottom - viewpoint_camera.FovYtop) * 0.5)
 
     raster_settings = GaussianRasterizationSettings(
         image_height=int(viewpoint_camera.image_height),
diff --git a/scene/cameras.py b/scene/cameras.py
index abf6e5242b..2cad0b37cf 100644
--- a/scene/cameras.py
+++ b/scene/cameras.py
@@ -15,7 +15,7 @@ from utils.graphics_utils import getWorld2View2, getProjectionMatrix
 
 class Camera(nn.Module):
-    def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask,
+    def __init__(self, colmap_id, R, T, FovXleft, FovXright, FovYtop, FovYbottom, image, gt_alpha_mask,
                  image_name, uid,
                  trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda"
                  ):
@@ -25,8 +25,10 @@ def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask,
         self.colmap_id = colmap_id
         self.R = R
         self.T = T
-        self.FoVx = FoVx
-        self.FoVy = FoVy
+        self.FovXleft = FovXleft
+        self.FovXright = FovXright
+        self.FovYtop = FovYtop
+        self.FovYbottom = FovYbottom
         self.image_name = image_name
 
         try:
@@ -52,7 +54,7 @@ def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask,
         self.scale = scale
 
         self.world_view_transform = torch.tensor(getWorld2View2(R, T, trans, scale)).transpose(0, 1).cuda()
-        self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy).transpose(0,1).cuda()
+        self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovXleft=FovXleft, fovXright=FovXright, fovYtop=FovYtop, fovYbottom=FovYbottom).transpose(0,1).cuda()
         self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0)
         self.camera_center = self.world_view_transform.inverse()[3, :3]
 
@@ -60,8 +62,8 @@ class MiniCam:
     def __init__(self, width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform):
         self.image_width = width
         self.image_height = height
-        self.FoVy = fovy
-        self.FoVx = fovx
+        self.FovYtop = fovy
+        self.FovXleft = fovx
         self.znear = znear
         self.zfar = zfar
         self.world_view_transform = world_view_transform
diff --git a/scene/dataset_readers.py b/scene/dataset_readers.py
index 2a6f904a92..83982c9bb9 100644
--- a/scene/dataset_readers.py
+++ b/scene/dataset_readers.py
@@ -15,7 +15,7 @@ from typing import NamedTuple
 from scene.colmap_loader import read_extrinsics_text, read_intrinsics_text, qvec2rotmat, \
     read_extrinsics_binary, read_intrinsics_binary, read_points3D_binary, read_points3D_text
-from utils.graphics_utils import getWorld2View2, focal2fov, fov2focal
+from utils.graphics_utils import getWorld2View2, focal2sidefov, focal2fov, fov2focal
 import numpy as np
 import json
 from pathlib import Path
@@ -27,8 +27,10 @@ class CameraInfo(NamedTuple):
     uid: int
     R: np.array
     T: np.array
-    FovY: np.array
-    FovX: np.array
+    FovYtop: np.array
+    FovYbottom: np.array
+    FovXleft: np.array
+    FovXright: np.array
     image: np.array
     image_path: str
     image_name: str
@@ -82,15 +84,23 @@ def readColmapCameras(cam_extrinsics, cam_intrinsics, images_folder):
         R = np.transpose(qvec2rotmat(extr.qvec))
         T = np.array(extr.tvec)
 
+        cx = intr.params[-2]
+        cy = intr.params[-1]
+
         if intr.model=="SIMPLE_PINHOLE":
             focal_length_x = intr.params[0]
-            FovY = focal2fov(focal_length_x, height)
-            FovX = focal2fov(focal_length_x, width)
+            focal_length_y = intr.params[0]
+            FovYtop = focal2sidefov(focal_length_y, -cy) # usually negative
+            FovYbottom = focal2sidefov(focal_length_y, height - cy) # usually positive
+            FovXleft = focal2sidefov(focal_length_x, -cx) # usually negative
+            FovXright = focal2sidefov(focal_length_x, width - cx) # usually positive
         elif intr.model=="PINHOLE":
             focal_length_x = intr.params[0]
             focal_length_y = intr.params[1]
-            FovY = focal2fov(focal_length_y, height)
-            FovX = focal2fov(focal_length_x, width)
+            FovYtop = focal2sidefov(focal_length_y, -cy) # usually negative
+            FovYbottom = focal2sidefov(focal_length_y, height - cy) # usually positive
+            FovXleft = focal2sidefov(focal_length_x, -cx) # usually negative
+            FovXright = focal2sidefov(focal_length_x, width - cx) # usually positive
         else:
             assert False, "Colmap camera model not handled: only undistorted datasets (PINHOLE or SIMPLE_PINHOLE cameras) supported!"
@@ -98,7 +108,7 @@ def readColmapCameras(cam_extrinsics, cam_intrinsics, images_folder):
         image_name = os.path.basename(image_path).split(".")[0]
         image = Image.open(image_path)
 
-        cam_info = CameraInfo(uid=uid, R=R, T=T, FovY=FovY, FovX=FovX, image=image,
+        cam_info = CameraInfo(uid=uid, R=R, T=T, FovYtop=FovYtop, FovYbottom=FovYbottom, FovXleft=FovXleft, FovXright=FovXright, image=image,
                               image_path=image_path, image_name=image_name, width=width, height=height)
         cam_infos.append(cam_info)
     sys.stdout.write('\n')
@@ -181,7 +191,6 @@ def readCamerasFromTransforms(path, transformsfile, white_background, extension=
     with open(os.path.join(path, transformsfile)) as json_file:
         contents = json.load(json_file)
-        fovx = contents["camera_angle_x"]
 
         frames = contents["frames"]
         for idx, frame in enumerate(frames):
@@ -209,13 +218,21 @@ def readCamerasFromTransforms(path, transformsfile, white_background, extension=
             arr = norm_data[:,:,:3] * norm_data[:, :, 3:4] + bg * (1 - norm_data[:, :, 3:4])
             image = Image.fromarray(np.array(arr*255.0, dtype=np.byte), "RGB")
 
-            fovy = focal2fov(fov2focal(fovx, image.size[0]), image.size[1])
-            FovY = fovy
-            FovX = fovx
-
-            cam_infos.append(CameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX, image=image,
+            cx = frame["cx"] if "cx" in frame else contents["cx"] if "cx" in contents else image.size[0] / 2
+            cy = frame["cy"] if "cy" in frame else contents["cy"] if "cy" in contents else image.size[1] / 2
+            fl_y = frame["fl_y"] if "fl_y" in frame else contents["fl_y"] if "fl_y" in contents else None
+            fl_x = frame["fl_x"] if "fl_x" in frame else contents["fl_x"] if "fl_x" in contents else fl_y
+            fovx = frame["camera_angle_x"] if "camera_angle_x" in frame else contents["camera_angle_x"] if "camera_angle_x" in contents else None
+            fovy = frame["camera_angle_y"] if "camera_angle_y" in frame else contents["camera_angle_y"] if "camera_angle_y" in contents else focal2fov(fov2focal(fovx, image.size[0]), image.size[1])
+            # priority is given to ("fl_x", "cx") over "camera_angle_x" because it can be frame specific:
+            fovYtop = focal2sidefov(fl_y, -cy) if fl_y else focal2sidefov(fov2focal(fovy, image.size[1]), -cy) # usually negative
+            fovYbottom = focal2sidefov(fl_y, image.size[1] - cy) if fl_y else focal2sidefov(fov2focal(fovy, image.size[1]), image.size[1] - cy) # usually positive
+            fovXleft = focal2sidefov(fl_x, -cx) if fl_x else focal2sidefov(fov2focal(fovx, image.size[0]), -cx) # usually negative
+            fovXright = focal2sidefov(fl_x, image.size[0] - cx) if fl_x else focal2sidefov(fov2focal(fovx, image.size[0]), image.size[0] - cx) # usually positive
+
+            cam_infos.append(CameraInfo(uid=idx, R=R, T=T, FovYtop=fovYtop, FovYbottom=fovYbottom, FovXleft=fovXleft, FovXright=fovXright, image=image,
                             image_path=image_path, image_name=image_name, width=image.size[0], height=image.size[1]))
-            
+
     return cam_infos
 
 def readNerfSyntheticInfo(path, white_background, eval, extension=".png"):
diff --git a/utils/camera_utils.py b/utils/camera_utils.py
index 1a54d0ada0..74349c81fc 100644
--- a/utils/camera_utils.py
+++ b/utils/camera_utils.py
@@ -13,6 +13,7 @@
 import numpy as np
 from utils.general_utils import PILtoTorch
 from utils.graphics_utils import fov2focal
+import math
 
 WARNED = False
 
@@ -47,7 +48,7 @@ def loadCam(args, id, cam_info, resolution_scale):
         loaded_mask = resized_image_rgb[3:4, ...]
 
     return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T,
-                  FoVx=cam_info.FovX, FoVy=cam_info.FovY,
+                  FovXleft=cam_info.FovXleft, FovXright=cam_info.FovXright, FovYtop=cam_info.FovYtop, FovYbottom=cam_info.FovYbottom,
                   image=gt_image, gt_alpha_mask=loaded_mask,
                   image_name=cam_info.image_name, uid=id, data_device=args.data_device)
 
@@ -76,7 +77,7 @@ def camera_to_JSON(id, camera : Camera):
         'height' : camera.height,
         'position': pos.tolist(),
         'rotation': serializable_array_2d,
-        'fy' : fov2focal(camera.FovY, camera.height),
-        'fx' : fov2focal(camera.FovX, camera.width)
+        'fy' : camera.height / (2 * math.tan((camera.FovYbottom - camera.FovYtop) / 2)),
+        'fx' : camera.width / (2 * math.tan((camera.FovXright - camera.FovXleft) / 2))
     }
     return camera_entry
diff --git a/utils/graphics_utils.py b/utils/graphics_utils.py
index b4627d837c..91db41ed0b 100644
--- a/utils/graphics_utils.py
+++ b/utils/graphics_utils.py
@@ -48,30 +48,39 @@ def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0):
     Rt = np.linalg.inv(C2W)
     return np.float32(Rt)
 
-def getProjectionMatrix(znear, zfar, fovX, fovY):
-    tanHalfFovY = math.tan((fovY / 2))
-    tanHalfFovX = math.tan((fovX / 2))
+def getProjectionMatrix(znear, zfar, fovXleft, fovXright, fovYtop, fovYbottom):
+    tanHalfFovYtop = math.tan(fovYtop)
+    tanHalfFovYbottom = math.tan(fovYbottom)
+    tanHalfFovXleft = math.tan(fovXleft)
+    tanHalfFovXright = math.tan(fovXright)
 
-    top = tanHalfFovY * znear
-    bottom = -top
-    right = tanHalfFovX * znear
-    left = -right
+    top = tanHalfFovYtop * znear
+    bottom = tanHalfFovYbottom * znear
+    left = tanHalfFovXleft * znear
+    right = tanHalfFovXright * znear
 
     P = torch.zeros(4, 4)
 
     z_sign = 1.0
 
+    # note that my conventions are (fovXleft,fovYtop) negative and (fovXright,fovYbottom) positive
     P[0, 0] = 2.0 * znear / (right - left)
-    P[1, 1] = 2.0 * znear / (top - bottom)
-    P[0, 2] = (right + left) / (right - left)
-    P[1, 2] = (top + bottom) / (top - bottom)
+    P[1, 1] = 2.0 * znear / (bottom - top)
+    P[0, 2] = -(right + left) / (right - left)
+    P[1, 2] = -(top + bottom) / (bottom - top)
     P[3, 2] = z_sign
     P[2, 2] = z_sign * zfar / (zfar - znear)
     P[2, 3] = -(zfar * znear) / (zfar - znear)
     return P
 
 def fov2focal(fov, pixels):
-    return pixels / (2 * math.tan(fov / 2))
+    return sidefov2focal(fov / 2, pixels / 2)
 
 def focal2fov(focal, pixels):
-    return 2*math.atan(pixels/(2*focal))
\ No newline at end of file
+    return 2 * focal2sidefov(focal, pixels / 2)
+
+def sidefov2focal(sidefov, sidepixels):
+    return sidepixels / math.tan(sidefov)
+
+def focal2sidefov(focal, sidepixels):
+    return math.atan(sidepixels / focal)

From 0a948b363fc380c1b6f17e735480c2c3e0d905e6 Mon Sep 17 00:00:00 2001
From: Matthieu Gendrin
Date: Mon, 26 Feb 2024 14:16:11 +0100
Subject: [PATCH 2/3] raster_settings: pass projection_matrix instead of
 full_proj_transform to the CUDA code

full_proj_transform is the product of the view matrix and the projection
matrix, so passing (view_matrix, projection_matrix) carries the same
information as (view_matrix, full_proj_transform). Passing
projection_matrix to rasterize_gaussians makes the intrinsics available
in the rendering code. Since the principal point (cx, cy) will be needed
there, that is the purpose of this refactoring commit.
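
For illustration, here is a minimal sketch of how the intrinsics could be
read back from that matrix. It is not part of the patch: the helper name is
made up, it assumes the sign convention of PATCH 1/3 (left/top angles
negative), and P is the matrix as returned by getProjectionMatrix(), i.e.
before the .transpose(0, 1) applied in Camera.__init__():

    def intrinsics_from_projmatrix(P, width, height):
        # With left  = znear * tan(fovXleft)  = -znear * cx / fx and
        #      right = znear * tan(fovXright) =  znear * (width - cx) / fx:
        #   P[0, 0] = 2*znear/(right - left)         = 2 * fx / width
        #   P[0, 2] = -(right + left)/(right - left) = 2 * cx / width - 1
        # (and the same along Y with fy, cy and height)
        fx = P[0, 0] * width / 2
        fy = P[1, 1] * height / 2
        cx = (P[0, 2] + 1) * width / 2
        cy = (P[1, 2] + 1) * height / 2
        return fx, fy, cx, cy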
Signed-off-by: Matthieu Gendrin
---
 gaussian_renderer/__init__.py          | 2 +-
 submodules/diff-gaussian-rasterization | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/gaussian_renderer/__init__.py b/gaussian_renderer/__init__.py
index 455dbf1c22..04728e459d 100644
--- a/gaussian_renderer/__init__.py
+++ b/gaussian_renderer/__init__.py
@@ -41,7 +41,7 @@ def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor,
         bg=bg_color,
         scale_modifier=scaling_modifier,
         viewmatrix=viewpoint_camera.world_view_transform,
-        projmatrix=viewpoint_camera.full_proj_transform,
+        projmatrix=viewpoint_camera.projection_matrix,
         sh_degree=pc.active_sh_degree,
         campos=viewpoint_camera.camera_center,
         prefiltered=False,
diff --git a/submodules/diff-gaussian-rasterization b/submodules/diff-gaussian-rasterization
index 59f5f77e3d..e8b2476380 160000
--- a/submodules/diff-gaussian-rasterization
+++ b/submodules/diff-gaussian-rasterization
@@ -1 +1 @@
-Subproject commit 59f5f77e3ddbac3ed9db93ec2cfe99ed6c5d121d
+Subproject commit e8b24763806263493f74deadf5f18b68f7cab0e1

From 54f1ca8815a1edd0b48d905aa38af3d26bfb1888 Mon Sep 17 00:00:00 2001
From: Matthieu Gendrin
Date: Tue, 27 Feb 2024 10:33:51 +0100
Subject: [PATCH 3/3] Remove references to tan_fovx and tan_fovy

fovx and fovy are no longer needed: the intrinsics are read from
projection_matrix, so we no longer pass them to the rendering code.

Signed-off-by: Matthieu Gendrin
---
 gaussian_renderer/__init__.py | 6 ------
 1 file changed, 6 deletions(-)

diff --git a/gaussian_renderer/__init__.py b/gaussian_renderer/__init__.py
index 04728e459d..3582c10ac9 100644
--- a/gaussian_renderer/__init__.py
+++ b/gaussian_renderer/__init__.py
@@ -29,15 +29,9 @@ def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor,
     except:
         pass
 
-    # Set up rasterization configuration
-    tanfovx = math.tan((viewpoint_camera.FovXright - viewpoint_camera.FovXleft) * 0.5)
-    tanfovy = math.tan((viewpoint_camera.FovYbottom - viewpoint_camera.FovYtop) * 0.5)
-
     raster_settings = GaussianRasterizationSettings(
         image_height=int(viewpoint_camera.image_height),
         image_width=int(viewpoint_camera.image_width),
-        tanfovx=tanfovx,
-        tanfovy=tanfovy,
         bg=bg_color,
         scale_modifier=scaling_modifier,
         viewmatrix=viewpoint_camera.world_view_transform,
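
As a quick sanity check of the split-FoV convention used across these three
patches (a standalone sketch, not part of any commit: it re-states the
focal2sidefov helper added to utils/graphics_utils.py and uses made-up
example numbers), a principal point at the image center reduces to the
original symmetric FoV:

    import math

    def focal2sidefov(focal, sidepixels):
        return math.atan(sidepixels / focal)

    fx, width = 1000.0, 1600                     # example intrinsics
    cx = width / 2                               # principal point at the image center
    FovXleft = focal2sidefov(fx, -cx)            # negative half-angle
    FovXright = focal2sidefov(fx, width - cx)    # positive half-angle
    fovx = FovXright - FovXleft                  # = 2 * atan(width / (2 * fx)), the original FoVx
    assert math.isclose(math.tan(fovx * 0.5), width / (2 * fx))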