@@ -174,19 +174,40 @@ def render_gaussians(
     )
 
 def get_extended_compute_locally(cuda_args, image_height, image_width):
-    mp_rank = int(cuda_args["mp_rank"])
-    dist_global_strategy = [int(x) for x in cuda_args["dist_global_strategy"].split(",")]
+    if isinstance(cuda_args["dist_global_strategy"], str):
+        mp_rank = int(cuda_args["mp_rank"])
+        dist_global_strategy = [int(x) for x in cuda_args["dist_global_strategy"].split(",")]
+
+        num_tile_y = (image_height + 16 - 1) // 16  # TODO: this is dangerous because 16 may change.
+        num_tile_x = (image_width + 16 - 1) // 16
+        tile_l = max(dist_global_strategy[mp_rank] - num_tile_x - 1, 0)
+        tile_r = min(dist_global_strategy[mp_rank + 1] + num_tile_x + 1, num_tile_y * num_tile_x)
+
+        extended_compute_locally = torch.zeros(num_tile_y * num_tile_x, dtype=torch.bool, device="cuda")
+        extended_compute_locally[tile_l:tile_r] = True
+        extended_compute_locally = extended_compute_locally.view(num_tile_y, num_tile_x)
 
-    num_tile_y = (image_height + 16 - 1) // 16  # TODO: this is dangerous because 16 may change.
-    num_tile_x = (image_width + 16 - 1) // 16
-    tile_l = max(dist_global_strategy[mp_rank] - num_tile_x - 1, 0)
-    tile_r = min(dist_global_strategy[mp_rank + 1] + num_tile_x + 1, num_tile_y * num_tile_x)
+        return extended_compute_locally
+    else:
+        division_pos = cuda_args["dist_global_strategy"]
+        division_pos_xs, division_pos_ys = division_pos
+        mp_rank = int(cuda_args["mp_rank"])
+        grid_size_x = len(division_pos_xs) - 1
+        grid_size_y = len(division_pos_ys[0]) - 1
+        y_rank = mp_rank // grid_size_x
+        x_rank = mp_rank % grid_size_x
+
+        local_tile_x_l, local_tile_x_r = division_pos_xs[x_rank], division_pos_xs[x_rank + 1]
+        local_tile_y_l, local_tile_y_r = division_pos_ys[x_rank][y_rank], division_pos_ys[x_rank][y_rank + 1]
 
-    extended_compute_locally = torch.zeros(num_tile_y * num_tile_x, dtype=torch.bool, device="cuda")
-    extended_compute_locally[tile_l:tile_r] = True
-    extended_compute_locally = extended_compute_locally.view(num_tile_y, num_tile_x)
+        num_tile_y = (image_height + 16 - 1) // 16
+        num_tile_x = (image_width + 16 - 1) // 16
 
-    return extended_compute_locally
+        extended_compute_locally = torch.zeros((num_tile_y, num_tile_x), dtype=torch.bool, device="cuda")
+        extended_compute_locally[max(local_tile_y_l - 1, 0):min(local_tile_y_r + 1, num_tile_y),
+                                 max(local_tile_x_l - 1, 0):min(local_tile_x_r + 1, num_tile_x)] = True
+
+        return extended_compute_locally
 
 class _RenderGaussians(torch.autograd.Function):
     @staticmethod
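For orientation, here is a small hypothetical usage sketch of the two `dist_global_strategy` formats that the patched `get_extended_compute_locally` accepts. The concrete image sizes, division positions, and ranks below are illustrative only, not from this patch, and a CUDA device is required because the mask is allocated on `"cuda"`.

# Legacy string format: comma-separated cumulative tile offsets, giving each
# rank one contiguous span of flattened tiles.
cuda_args = {"mp_rank": "1", "dist_global_strategy": "0,40,80,120,160"}
mask = get_extended_compute_locally(cuda_args, image_height=256, image_width=640)
# 256x640 pixels -> 16x40 tiles; rank 1 owns flattened tiles [40, 80),
# extended by num_tile_x + 1 tiles on each side, then reshaped to (16, 40).

# New tuple format: a 2D grid of tile rectangles.
division_pos_xs = [0, 20, 40]               # x tile boundaries, shared by all columns
division_pos_ys = [[0, 8, 16], [0, 8, 16]]  # y tile boundaries per column
cuda_args = {"mp_rank": "3", "dist_global_strategy": (division_pos_xs, division_pos_ys)}
mask = get_extended_compute_locally(cuda_args, image_height=256, image_width=640)
# rank 3 -> (x_rank=1, y_rank=1), i.e. tiles y in [8, 16), x in [20, 40),
# dilated by one tile in every direction within the (16, 40) grid.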
@@ -367,35 +388,66 @@ def render_gaussians(self, means2D, conic_opacity, rgb, depths, radii, compute_l
 
     def get_local2j_ids(self, means2D, radii, cuda_args):
 
-        raster_settings = self.raster_settings
-        mp_world_size = int(cuda_args["mp_world_size"])
-        mp_rank = int(cuda_args["mp_rank"])
+        if isinstance(cuda_args["dist_global_strategy"], str):
+            raster_settings = self.raster_settings
+            mp_world_size = int(cuda_args["mp_world_size"])
+            mp_rank = int(cuda_args["mp_rank"])
 
-        # TODO: make it more general.
-        dist_global_strategy = [int(x) for x in cuda_args["dist_global_strategy"].split(",")]
-        assert len(dist_global_strategy) == mp_world_size + 1, "dist_global_strategy should have length WORLD_SIZE+1"
-        assert dist_global_strategy[0] == 0, "dist_global_strategy[0] should be 0"
-        dist_global_strategy = torch.tensor(dist_global_strategy, dtype=torch.int, device=means2D.device)
+            # TODO: make it more general.
+            dist_global_strategy = [int(x) for x in cuda_args["dist_global_strategy"].split(",")]
+            assert len(dist_global_strategy) == mp_world_size + 1, "dist_global_strategy should have length WORLD_SIZE+1"
+            assert dist_global_strategy[0] == 0, "dist_global_strategy[0] should be 0"
+            dist_global_strategy = torch.tensor(dist_global_strategy, dtype=torch.int, device=means2D.device)
 
-        args = (
-            raster_settings.image_height,
-            raster_settings.image_width,
-            mp_rank,
-            mp_world_size,
-            means2D,
-            radii,
-            dist_global_strategy,
-            cuda_args
-        )
+            args = (
+                raster_settings.image_height,
+                raster_settings.image_width,
+                mp_rank,
+                mp_world_size,
+                means2D,
+                radii,
+                dist_global_strategy,
+                cuda_args
+            )
+
+            local2j_ids_bool = _C.get_local2j_ids_bool(*args)  # local2j_ids_bool is a (P, world_size) bool tensor
+
+        else:
+            raster_settings = self.raster_settings
+            mp_world_size = int(cuda_args["mp_world_size"])
+            mp_rank = int(cuda_args["mp_rank"])
 
-        local2j_ids_bool = _C.get_local2j_ids_bool(*args)  # local2j_ids_bool is a (P, world_size) bool tensor
+            division_pos = cuda_args["dist_global_strategy"]
+            division_pos_xs, division_pos_ys = division_pos
+
+            rectangles = []
+            for y_rank in range(len(division_pos_ys[0]) - 1):
+                for x_rank in range(len(division_pos_ys)):
+                    local_tile_x_l, local_tile_x_r = division_pos_xs[x_rank], division_pos_xs[x_rank + 1]
+                    local_tile_y_l, local_tile_y_r = division_pos_ys[x_rank][y_rank], division_pos_ys[x_rank][y_rank + 1]
+                    rectangles.append([local_tile_y_l, local_tile_y_r, local_tile_x_l, local_tile_x_r])
+            rectangles = torch.tensor(rectangles, dtype=torch.int, device=means2D.device)  # (mp_world_size, 4)
+
+            args = (
+                raster_settings.image_height,
+                raster_settings.image_width,
+                mp_rank,
+                mp_world_size,
+                means2D,
+                radii,
+                rectangles,
+                cuda_args
+            )
+
+            local2j_ids_bool = _C.get_local2j_ids_bool_adjust_mode6(*args)  # local2j_ids_bool is a (P, world_size) bool tensor
 
         local2j_ids = []
         for rk in range(mp_world_size):
             local2j_ids.append(local2j_ids_bool[:, rk].nonzero())
 
         return local2j_ids, local2j_ids_bool
 
+
     def get_distribution_strategy(self, means2D, radii, cuda_args):
 
         assert False, "This function is not used in the current version."
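The new `_C.get_local2j_ids_bool_adjust_mode6` kernel lives in the CUDA extension and is not shown in this diff. As a rough reference only, the sketch below is one plausible reading of what it computes from `rectangles`, assuming a Gaussian is forwarded to rank j iff its radius-r screen-space bounding box overlaps rank j's tile rectangle; the kernel's exact overlap and border rules may differ.

import torch

def local2j_ids_bool_reference(means2D, radii, rectangles, tile=16):
    # means2D: (P, 2) pixel positions; radii: (P,) ints (0 => culled);
    # rectangles: (world_size, 4) rows of [tile_y_l, tile_y_r, tile_x_l, tile_x_r),
    # right-exclusive, in tile units, as built in get_local2j_ids above.
    x, y = means2D[:, 0], means2D[:, 1]
    r = radii.float()
    # Tile-space bounding box of each Gaussian, inclusive on both ends.
    tx_l, tx_r = torch.floor((x - r) / tile), torch.floor((x + r) / tile)
    ty_l, ty_r = torch.floor((y - r) / tile), torch.floor((y + r) / tile)
    ry_l, ry_r, rx_l, rx_r = rectangles.T.float()  # four (world_size,) vectors
    # Interval-overlap test, broadcast to (P, world_size).
    overlap_x = (tx_l[:, None] < rx_r[None, :]) & (tx_r[:, None] >= rx_l[None, :])
    overlap_y = (ty_l[:, None] < ry_r[None, :]) & (ty_r[:, None] >= ry_l[None, :])
    return overlap_x & overlap_y & (radii > 0)[:, None]  # (P, world_size) bool

Per-rank index lists then follow exactly as in the unchanged tail of get_local2j_ids, by taking nonzero() of each column.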