Skip to content

Commit d802de4

Browse files
author
Hexu Zhao
committed
support adjust_mode==6; However, I do not know why it is slower than adjust_mode==5.
1 parent f39d729 commit d802de4

File tree

4 files changed

+169
-29
lines changed

4 files changed

+169
-29
lines changed

diff_gaussian_rasterization/__init__.py

Lines changed: 81 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -174,19 +174,40 @@ def render_gaussians(
174174
)
175175

176176
def get_extended_compute_locally(cuda_args, image_height, image_width):
177-
mp_rank = int(cuda_args["mp_rank"])
178-
dist_global_strategy = [int(x) for x in cuda_args["dist_global_strategy"].split(",")]
177+
if isinstance(cuda_args["dist_global_strategy"], str):
178+
mp_rank = int(cuda_args["mp_rank"])
179+
dist_global_strategy = [int(x) for x in cuda_args["dist_global_strategy"].split(",")]
180+
181+
num_tile_y = (image_height + 16 - 1) // 16 #TODO: this is dangerous because 16 may change.
182+
num_tile_x = (image_width + 16 - 1) // 16
183+
tile_l = max(dist_global_strategy[mp_rank]-num_tile_x-1, 0)
184+
tile_r = min(dist_global_strategy[mp_rank+1]+num_tile_x+1, num_tile_y*num_tile_x)
185+
186+
extended_compute_locally = torch.zeros(num_tile_y*num_tile_x, dtype=torch.bool, device="cuda")
187+
extended_compute_locally[tile_l:tile_r] = True
188+
extended_compute_locally = extended_compute_locally.view(num_tile_y, num_tile_x)
179189

180-
num_tile_y = (image_height + 16 - 1) // 16 #TODO: this is dangerous because 16 may change.
181-
num_tile_x = (image_width + 16 - 1) // 16
182-
tile_l = max(dist_global_strategy[mp_rank]-num_tile_x-1, 0)
183-
tile_r = min(dist_global_strategy[mp_rank+1]+num_tile_x+1, num_tile_y*num_tile_x)
190+
return extended_compute_locally
191+
else:
192+
division_pos = cuda_args["dist_global_strategy"]
193+
division_pos_xs, division_pos_ys = division_pos
194+
mp_rank = int(cuda_args["mp_rank"])
195+
grid_size_x = len(division_pos_xs) - 1
196+
grid_size_y = len(division_pos_ys[0]) - 1
197+
y_rank = mp_rank // grid_size_x
198+
x_rank = mp_rank % grid_size_x
199+
200+
local_tile_x_l, local_tile_x_r = division_pos_xs[x_rank], division_pos_xs[x_rank+1]
201+
local_tile_y_l, local_tile_y_r = division_pos_ys[x_rank][y_rank], division_pos_ys[x_rank][y_rank+1]
184202

185-
extended_compute_locally = torch.zeros(num_tile_y*num_tile_x, dtype=torch.bool, device="cuda")
186-
extended_compute_locally[tile_l:tile_r] = True
187-
extended_compute_locally = extended_compute_locally.view(num_tile_y, num_tile_x)
203+
num_tile_y = (image_height + 16 - 1) // 16
204+
num_tile_x = (image_width + 16 - 1) // 16
188205

189-
return extended_compute_locally
206+
extended_compute_locally = torch.zeros((num_tile_y, num_tile_x), dtype=torch.bool, device="cuda")
207+
extended_compute_locally[max(local_tile_y_l-1,0):min(local_tile_y_r+1,num_tile_y),
208+
max(local_tile_x_l-1,0):min(local_tile_x_r+1,num_tile_x)] = True
209+
210+
return extended_compute_locally
190211

191212
class _RenderGaussians(torch.autograd.Function):
192213
@staticmethod
@@ -367,35 +388,66 @@ def render_gaussians(self, means2D, conic_opacity, rgb, depths, radii, compute_l
367388

368389
def get_local2j_ids(self, means2D, radii, cuda_args):
369390

370-
raster_settings = self.raster_settings
371-
mp_world_size = int(cuda_args["mp_world_size"])
372-
mp_rank = int(cuda_args["mp_rank"])
391+
if isinstance(cuda_args["dist_global_strategy"], str):
392+
raster_settings = self.raster_settings
393+
mp_world_size = int(cuda_args["mp_world_size"])
394+
mp_rank = int(cuda_args["mp_rank"])
373395

374-
# TODO: make it more general.
375-
dist_global_strategy = [int(x) for x in cuda_args["dist_global_strategy"].split(",")]
376-
assert len(dist_global_strategy) == mp_world_size+1, "dist_global_strategy should have length WORLD_SIZE+1"
377-
assert dist_global_strategy[0] == 0, "dist_global_strategy[0] should be 0"
378-
dist_global_strategy = torch.tensor(dist_global_strategy, dtype=torch.int, device=means2D.device)
396+
# TODO: make it more general.
397+
dist_global_strategy = [int(x) for x in cuda_args["dist_global_strategy"].split(",")]
398+
assert len(dist_global_strategy) == mp_world_size+1, "dist_global_strategy should have length WORLD_SIZE+1"
399+
assert dist_global_strategy[0] == 0, "dist_global_strategy[0] should be 0"
400+
dist_global_strategy = torch.tensor(dist_global_strategy, dtype=torch.int, device=means2D.device)
379401

380-
args = (
381-
raster_settings.image_height,
382-
raster_settings.image_width,
383-
mp_rank,
384-
mp_world_size,
385-
means2D,
386-
radii,
387-
dist_global_strategy,
388-
cuda_args
389-
)
402+
args = (
403+
raster_settings.image_height,
404+
raster_settings.image_width,
405+
mp_rank,
406+
mp_world_size,
407+
means2D,
408+
radii,
409+
dist_global_strategy,
410+
cuda_args
411+
)
412+
413+
local2j_ids_bool = _C.get_local2j_ids_bool(*args) # local2j_ids_bool is (P, world_size) bool tensor
414+
415+
else:
416+
raster_settings = self.raster_settings
417+
mp_world_size = int(cuda_args["mp_world_size"])
418+
mp_rank = int(cuda_args["mp_rank"])
390419

391-
local2j_ids_bool = _C.get_local2j_ids_bool(*args) # local2j_ids_bool is (P, world_size) bool tensor
420+
division_pos = cuda_args["dist_global_strategy"]
421+
division_pos_xs, division_pos_ys = division_pos
422+
423+
rectangles = []
424+
for y_rank in range(len(division_pos_ys[0])-1):
425+
for x_rank in range(len(division_pos_ys)):
426+
local_tile_x_l, local_tile_x_r = division_pos_xs[x_rank], division_pos_xs[x_rank+1]
427+
local_tile_y_l, local_tile_y_r = division_pos_ys[x_rank][y_rank], division_pos_ys[x_rank][y_rank+1]
428+
rectangles.append([local_tile_y_l, local_tile_y_r, local_tile_x_l, local_tile_x_r])
429+
rectangles = torch.tensor(rectangles, dtype=torch.int, device=means2D.device)# (mp_world_size, 4)
430+
431+
args = (
432+
raster_settings.image_height,
433+
raster_settings.image_width,
434+
mp_rank,
435+
mp_world_size,
436+
means2D,
437+
radii,
438+
rectangles,
439+
cuda_args
440+
)
441+
442+
local2j_ids_bool = _C.get_local2j_ids_bool_adjust_mode6(*args) # local2j_ids_bool is (P, world_size) bool tensor
392443

393444
local2j_ids = []
394445
for rk in range(mp_world_size):
395446
local2j_ids.append(local2j_ids_bool[:, rk].nonzero())
396447

397448
return local2j_ids, local2j_ids_bool
398449

450+
399451
def get_distribution_strategy(self, means2D, radii, cuda_args):
400452

401453
assert False, "This function is not used in the current version."

ext.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
2121
m.def("render_gaussians", &RenderGaussiansCUDA);
2222
m.def("render_gaussians_backward", &RenderGaussiansBackwardCUDA);
2323
m.def("get_local2j_ids_bool", &GetLocal2jIdsBoolCUDA);
24+
m.def("get_local2j_ids_bool_adjust_mode6", &GetLocal2jIdsBoolAdjustMode6CUDA);
2425

2526
// Image Distribution Utilities
2627
m.def("get_touched_locally", &GetTouchedLocally);

rasterize_points.cu

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -483,7 +483,83 @@ torch::Tensor GetLocal2jIdsBoolCUDA(
483483
}
484484

485485

486+
__global__ void getTouchedIdsBoolAdjustMode6(
487+
int P,
488+
int height,
489+
int width,
490+
int world_size,
491+
const float2* means2D,
492+
const int* radii,// NOTE: radii is not const in getRect()
493+
const int* rectangles,
494+
bool* touchedIdsBool,
495+
bool avoid_pixel_all2all)
496+
{
497+
auto i = cg::this_grid().thread_rank();
498+
if (i < P)
499+
{
500+
uint2 rect_min, rect_max;
501+
dim3 tile_grid((width + BLOCK_X - 1) / BLOCK_X, (height + BLOCK_Y - 1) / BLOCK_Y, 1);
502+
503+
getRect(means2D[i], radii[i], rect_min, rect_max, tile_grid);
504+
505+
for (int rk = 0; rk < world_size; rk++)
506+
{
507+
// local_tile_y_l, local_tile_y_r, local_tile_x_l, local_tile_x_r
508+
const int* rectangles_offset = rectangles+(rk*4);
509+
int local_tile_y_l = *(rectangles_offset);
510+
int local_tile_y_r = *(rectangles_offset+1);
511+
int local_tile_x_l = *(rectangles_offset+2);
512+
int local_tile_x_r = *(rectangles_offset+3);
513+
514+
515+
516+
if (avoid_pixel_all2all) {
517+
if (local_tile_y_l>0) local_tile_y_l-=1;
518+
if (local_tile_x_l>0) local_tile_x_l-=1;//WERID: If local_tile_x_l changes to -1, then it gives weird behavior and I have not figure it out yet.
519+
local_tile_y_r+=1;
520+
local_tile_x_r+=1;
521+
}
522+
if (rect_max.y <= local_tile_y_l ||
523+
local_tile_y_r <= rect_min.y ||
524+
rect_max.x <= local_tile_x_l ||
525+
local_tile_x_r <= rect_min.x) continue;
526+
527+
touchedIdsBool[i * world_size + rk] = true;
528+
}
529+
}
530+
}
531+
532+
torch::Tensor GetLocal2jIdsBoolAdjustMode6CUDA(
533+
int image_height,
534+
int image_width,
535+
int mp_rank,
536+
int mp_world_size,
537+
const torch::Tensor& means2D,
538+
const torch::Tensor& radii,
539+
const torch::Tensor& rectangles,
540+
const pybind11::dict &args)
541+
{
542+
const int P = means2D.size(0);
543+
const int H = image_height;
544+
const int W = image_width;
545+
bool avoid_pixel_all2all = args["avoid_pixel_all2all"].cast<bool>();
486546

547+
torch::Tensor local2jIdsBool = torch::full({P, mp_world_size}, false, means2D.options().dtype(torch::kBool));
548+
549+
getTouchedIdsBoolAdjustMode6 << <(P + ONE_DIM_BLOCK_SIZE - 1) / ONE_DIM_BLOCK_SIZE, ONE_DIM_BLOCK_SIZE >> >(
550+
P,
551+
H,
552+
W,
553+
mp_world_size,
554+
reinterpret_cast<float2*>(means2D.contiguous().data<float>()),
555+
radii.contiguous().data<int>(),
556+
rectangles.contiguous().data<int>(),
557+
local2jIdsBool.contiguous().data<bool>(),
558+
avoid_pixel_all2all
559+
);
560+
561+
return local2jIdsBool;
562+
}
487563

488564

489565

rasterize_points.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,4 +177,15 @@ torch::Tensor GetLocal2jIdsBoolCUDA(
177177
const torch::Tensor& dist_global_strategy,
178178
const pybind11::dict &args);
179179

180+
torch::Tensor GetLocal2jIdsBoolAdjustMode6CUDA(
181+
int image_height,
182+
int image_width,
183+
int mp_rank,
184+
int mp_world_size,
185+
const torch::Tensor& means2D,
186+
const torch::Tensor& radii,
187+
const torch::Tensor& rectangles,
188+
const pybind11::dict &args);
189+
190+
180191
std::tuple<int, int, int> GetBlockXY();

0 commit comments

Comments
 (0)