@@ -74,6 +74,88 @@ def test_gaussian_rasterizer_time():
7474 preprocess_time = end_time - start_time
7575 print (f"Time taken by preprocess_gaussians: { preprocess_time :.4f} seconds" )
7676
77+ def test_improved_gaussian_rasterizer ():
78+
79+ # Set up the input data
80+ num_gaussians = 10000
81+ num_batches = 4
82+ means3D = torch .randn (num_gaussians , 3 ).cuda ()
83+ scales = torch .randn (num_gaussians , 3 ).cuda ()
84+ rotations = torch .randn (num_gaussians , 3 , 3 ).cuda ()
85+ shs = torch .randn (num_gaussians , 9 ).cuda ()
86+ opacity = torch .randn (num_gaussians , 1 ).cuda ()
87+
88+ # Set up the viewpoint cameras
89+ batched_viewpoint_cameras = []
90+ for _ in range (num_batches ):
91+ viewpoint_camera = type ('ViewpointCamera' , (), {})
92+ viewpoint_camera .FoVx = math .radians (60 )
93+ viewpoint_camera .FoVy = math .radians (60 )
94+ viewpoint_camera .image_height = 512
95+ viewpoint_camera .image_width = 512
96+ viewpoint_camera .world_view_transform = torch .eye (4 ).cuda ()
97+ viewpoint_camera .full_proj_transform = torch .eye (4 ).cuda ()
98+ viewpoint_camera .camera_center = torch .zeros (3 ).cuda ()
99+ batched_viewpoint_cameras .append (viewpoint_camera )
100+
101+ # Set up the strategies
102+ batched_strategies = [None ] * num_batches
103+
104+ # Set up other parameters
105+ bg_color = torch .ones (3 ).cuda ()
106+ scaling_modifier = 1.0
107+ pc = type ('PC' , (), {})
108+ pc .active_sh_degree = 2
109+ pipe = type ('Pipe' , (), {})
110+ pipe .debug = False
111+ mode = "train"
112+
113+ batched_rasterizers = []
114+ batched_cuda_args = []
115+ batched_screenspace_params = []
116+ batched_means2D = []
117+ batched_radii = []
118+ raster_settings_list = []
119+ for i , (viewpoint_camera , strategy ) in enumerate (zip (batched_viewpoint_cameras , batched_strategies )):
120+ ########## [START] Prepare CUDA Rasterization Settings ##########
121+ cuda_args = get_cuda_args (strategy , mode )
122+ batched_cuda_args .append (cuda_args )
123+
124+ # Set up rasterization configuration
125+ tanfovx = math .tan (viewpoint_camera .FoVx * 0.5 )
126+ tanfovy = math .tan (viewpoint_camera .FoVy * 0.5 )
127+
128+ raster_settings_list .append (GaussianRasterizationSettings (
129+ image_height = int (viewpoint_camera .image_height ),
130+ image_width = int (viewpoint_camera .image_width ),
131+ tanfovx = tanfovx ,
132+ tanfovy = tanfovy ,
133+ bg = bg_color ,
134+ scale_modifier = scaling_modifier ,
135+ viewmatrix = viewpoint_camera .world_view_transform ,
136+ projmatrix = viewpoint_camera .full_proj_transform ,
137+ sh_degree = pc .active_sh_degree ,
138+ campos = viewpoint_camera .camera_center ,
139+ prefiltered = False ,
140+ debug = pipe .debug
141+ ))
142+
143+
144+ rasterizer = GaussianRasterizerBatches (raster_settings = raster_settings_list )
145+ start_time = time .time ()
146+ batched_means2D , batched_rgb , batched_conic_opacity , batched_radii , batched_depths = rasterizer .preprocess_gaussians_batches (
147+ means3D = means3D ,
148+ scales = scales ,
149+ rotations = rotations ,
150+ shs = shs ,
151+ opacities = opacity ,
152+ cuda_args = batched_cuda_args
153+ )
154+ end_time = time .time ()
155+
156+ preprocess_time = end_time - start_time
157+ print (f"Time taken by preprocess_gaussians: { preprocess_time :.4f} seconds" )
158+
77159
78160def test_batched_gaussian_rasterizer ():
79161 # Set up the input data
@@ -163,5 +245,7 @@ def test_batched_gaussian_rasterizer():
163245 # Perform further operations with the batched results
164246 # ...
165247
248+
249+
166250if __name__ == "__main__" :
167251 test_gaussian_rasterizer_time ()
0 commit comments