2323
2424class GaussianModel :
2525
26- def setup_functions (self ):
26+ def setup_functions (self , dtype ):
2727 def build_covariance_from_scaling_rotation (scaling , scaling_modifier , rotation ):
28- L = build_scaling_rotation (scaling_modifier * scaling , rotation )
28+ L = build_scaling_rotation (scaling_modifier * scaling , rotation , dtype )
2929 actual_covariance = L @ L .transpose (1 , 2 )
30- symm = strip_symmetric (actual_covariance )
30+ symm = strip_symmetric (actual_covariance , dtype )
3131 return symm
3232
3333 self .scaling_activation = torch .exp
@@ -41,7 +41,7 @@ def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation):
4141 self .rotation_activation = torch .nn .functional .normalize
4242
4343
44- def __init__ (self , sh_degree : int ):
44+ def __init__ (self , sh_degree : int , dtype = torch . float32 ):
4545 self .active_sh_degree = 0
4646 self .max_sh_degree = sh_degree
4747 self ._xyz = torch .empty (0 )
@@ -56,7 +56,8 @@ def __init__(self, sh_degree : int):
5656 self .optimizer = None
5757 self .percent_dense = 0
5858 self .spatial_lr_scale = 0
59- self .setup_functions ()
59+ self .dtype = dtype
60+ self .setup_functions (dtype )
6061
6162 def capture (self ):
6263 return (
@@ -136,15 +137,15 @@ def create_from_pcd(self, pcd : BasicPointCloud, spatial_lr_scale : float):
136137 rots = torch .zeros ((fused_point_cloud .shape [0 ], 4 ), device = "cuda" )
137138 rots [:, 0 ] = 1
138139
139- opacities = inverse_sigmoid (0.1 * torch .ones ((fused_point_cloud .shape [0 ], 1 ), dtype = torch . float , device = "cuda" ))
140+ opacities = inverse_sigmoid (0.1 * torch .ones ((fused_point_cloud .shape [0 ], 1 ), dtype = self . dtype , device = "cuda" ))
140141
141142 self ._xyz = nn .Parameter (fused_point_cloud .requires_grad_ (True ))
142143 self ._features_dc = nn .Parameter (features [:,:,0 :1 ].transpose (1 , 2 ).contiguous ().requires_grad_ (True ))
143144 self ._features_rest = nn .Parameter (features [:,:,1 :].transpose (1 , 2 ).contiguous ().requires_grad_ (True ))
144145 self ._scaling = nn .Parameter (scales .requires_grad_ (True ))
145146 self ._rotation = nn .Parameter (rots .requires_grad_ (True ))
146147 self ._opacity = nn .Parameter (opacities .requires_grad_ (True ))
147- self .max_radii2D = torch .zeros ((self .get_xyz .shape [0 ]), device = "cuda" )
148+ self .max_radii2D = torch .zeros ((self .get_xyz .shape [0 ]), device = "cuda" , dtype = self . dtype )
148149
149150 def training_setup (self , training_args ):
150151 self .percent_dense = training_args .percent_dense
@@ -246,12 +247,12 @@ def load_ply(self, path):
246247 for idx , attr_name in enumerate (rot_names ):
247248 rots [:, idx ] = np .asarray (plydata .elements [0 ][attr_name ])
248249
249- self ._xyz = nn .Parameter (torch .tensor (xyz , dtype = torch . float , device = "cuda" ).requires_grad_ (True ))
250- self ._features_dc = nn .Parameter (torch .tensor (features_dc , dtype = torch . float , device = "cuda" ).transpose (1 , 2 ).contiguous ().requires_grad_ (True ))
251- self ._features_rest = nn .Parameter (torch .tensor (features_extra , dtype = torch . float , device = "cuda" ).transpose (1 , 2 ).contiguous ().requires_grad_ (True ))
252- self ._opacity = nn .Parameter (torch .tensor (opacities , dtype = torch . float , device = "cuda" ).requires_grad_ (True ))
253- self ._scaling = nn .Parameter (torch .tensor (scales , dtype = torch . float , device = "cuda" ).requires_grad_ (True ))
254- self ._rotation = nn .Parameter (torch .tensor (rots , dtype = torch . float , device = "cuda" ).requires_grad_ (True ))
250+ self ._xyz = nn .Parameter (torch .tensor (xyz , dtype = self . dtype , device = "cuda" ).requires_grad_ (True ))
251+ self ._features_dc = nn .Parameter (torch .tensor (features_dc , dtype = self . dtype , device = "cuda" ).transpose (1 , 2 ).contiguous ().requires_grad_ (True ))
252+ self ._features_rest = nn .Parameter (torch .tensor (features_extra , dtype = self . dtype , device = "cuda" ).transpose (1 , 2 ).contiguous ().requires_grad_ (True ))
253+ self ._opacity = nn .Parameter (torch .tensor (opacities , dtype = self . dtype , device = "cuda" ).requires_grad_ (True ))
254+ self ._scaling = nn .Parameter (torch .tensor (scales , dtype = self . dtype , device = "cuda" ).requires_grad_ (True ))
255+ self ._rotation = nn .Parameter (torch .tensor (rots , dtype = self . dtype , device = "cuda" ).requires_grad_ (True ))
255256
256257 self .active_sh_degree = self .max_sh_degree
257258
0 commit comments