1717class Camera (nn .Module ):
1818 def __init__ (self , colmap_id , R , T , FoVx , FoVy , image , gt_alpha_mask ,
1919 image_name , uid ,
20- trans = np .array ([0.0 , 0.0 , 0.0 ]), scale = 1.0 , data_device = "cuda"
20+ trans = np .array ([0.0 , 0.0 , 0.0 ]), scale = 1.0 , data_device = "cuda" , lazy_load = False
2121 ):
2222 super (Camera , self ).__init__ ()
2323
@@ -36,14 +36,17 @@ def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask,
3636 print (f"[Warning] Custom device { data_device } failed, fallback to default cuda device" )
3737 self .data_device = torch .device ("cuda" )
3838
39- self .original_image = image .clamp (0.0 , 1.0 )
39+ if lazy_load :
40+ self .data_device = torch .device ("cpu" )
41+
42+ self .original_image = image .clamp (0.0 , 1.0 ).to (self .data_device )
4043 self .image_width = self .original_image .shape [2 ]
4144 self .image_height = self .original_image .shape [1 ]
4245
4346 if gt_alpha_mask is not None :
44- self .original_image *= gt_alpha_mask
47+ self .original_image *= gt_alpha_mask . to ( self . data_device )
4548 else :
46- self .original_image *= torch .ones ((1 , self .image_height , self .image_width ))
49+ self .original_image *= torch .ones ((1 , self .image_height , self .image_width ), device = self . data_device )
4750
4851 self .zfar = 100.0
4952 self .znear = 0.01
0 commit comments