1313import torch .nn as nn
1414import torch
1515from . import _C
16+ import time
1617
1718def cpu_deep_copy_tuple (input_tuple ):
1819 copied_tensors = [item .cpu ().clone () if isinstance (item , torch .Tensor ) else item for item in input_tuple ]
@@ -202,7 +203,7 @@ def forward(
202203 ):
203204
204205 # means2D = means2D[:,:2].contiguous()
205- # TODO : double check.
206+ # NOTE : double check.
206207 # means2D is padded to (P, 3) before being output from preprocess_gaussians.
207208 # because _RenderGaussians.backward will give dL_dmeans2D with shape (P, 3).
208209 # Here, because the means2D in cuda code is (P, 2), we need to remove the padding.
@@ -228,12 +229,19 @@ def forward(
228229 cuda_args
229230 )
230231
232+ # This time is to measure: render forward+loss forward+loss backward+render backward . And then used to do load balancing.
233+ # Used when cuda_args["avoid_pixel_all2all"] is True.
234+ # It is not useful for now.
235+ # torch.cuda.synchronize()
236+ # render_forward_start_time = time.time()
237+
231238 num_rendered , color , n_render , n_consider , n_contrib , geomBuffer , binningBuffer , imgBuffer = _C .render_gaussians (* args )
232239
233240 # Keep relevant tensors for backward
234241 ctx .raster_settings = raster_settings
235242 ctx .cuda_args = cuda_args
236243 ctx .num_rendered = num_rendered
244+ # ctx.render_forward_start_time = render_forward_start_time
237245 ctx .save_for_backward (means2D , conic_opacity , rgb , geomBuffer , binningBuffer , imgBuffer , compute_locally , extended_compute_locally )
238246 ctx .mark_non_differentiable (n_render , n_consider , n_contrib )
239247
@@ -247,6 +255,7 @@ def backward(ctx, grad_color, grad_n_render, grad_n_consider, grad_n_contrib):
247255 num_rendered = ctx .num_rendered
248256 raster_settings = ctx .raster_settings
249257 cuda_args = ctx .cuda_args
258+ # render_forward_start_time = ctx.render_forward_start_time
250259 means2D , conic_opacity , rgb , geomBuffer , binningBuffer , imgBuffer , compute_locally , extended_compute_locally = ctx .saved_tensors
251260
252261 # Restructure args as C++ method expects them
@@ -265,6 +274,11 @@ def backward(ctx, grad_color, grad_n_render, grad_n_consider, grad_n_contrib):
265274
266275 dL_dmeans2D , dL_dconic_opacity , dL_dcolors = _C .render_gaussians_backward (* args )
267276
277+ # Used when cuda_args["avoid_pixel_all2all"] is True.
278+ # torch.cuda.synchronize()
279+ # render_backward_end_time = time.time()
280+ # cuda_args["stats_collector"]["pixelwise_workloads_time"] = (render_backward_end_time - render_forward_start_time)*1000
281+
268282 # change dL_dmeans2D from (P, 3) to (P, 2)
269283 # dL_dmeans2D is now (P, 3) because of render backwards' cuda implementation.
270284 dL_dmeans2D = dL_dmeans2D [:,:2 ]
0 commit comments