@@ -110,7 +110,7 @@ def __init__(
         self.qkv = nn.Linear(dim, dim * 3, bias=False)
         if qkv_bias:
             self.q_bias = nn.Parameter(torch.zeros(dim))
-            self.register_buffer('k_bias', torch.zeros(dim), persistent=False)
+            self.register_buffer('k_bias', torch.zeros(dim, device="cpu"), persistent=False)
             self.v_bias = nn.Parameter(torch.zeros(dim))
         else:
             self.q_bias = None
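
For context, `k_bias` is a constant all-zeros buffer (registered with `persistent=False`, so it is skipped by the state dict), while `q_bias` and `v_bias` remain learnable parameters; SwinV2 applies no bias to the key projection. Below is a minimal sketch of how such a bias trio is typically folded into a fused QKV projection; the module and variable names are illustrative, not the exact code in this file:

```python
import torch
import torch.nn.functional as F
from torch import nn


class FusedQKV(nn.Module):
    """Illustrative fused QKV projection: learnable q/v biases, frozen zero k bias."""

    def __init__(self, dim: int):
        super().__init__()
        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        self.q_bias = nn.Parameter(torch.zeros(dim))
        # Buffer, not Parameter: moves with .to()/.cuda(), receives no gradient,
        # and persistent=False keeps it out of checkpoints.
        self.register_buffer('k_bias', torch.zeros(dim), persistent=False)
        self.v_bias = nn.Parameter(torch.zeros(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Concatenate the three biases into one (3 * dim,) vector so a single
        # F.linear call produces q, k and v in one matmul.
        qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias))
        return F.linear(x, self.qkv.weight, bias=qkv_bias)


x = torch.randn(2, 16, 32)            # (batch, tokens, dim)
print(FusedQKV(32)(x).shape)          # torch.Size([2, 16, 96])
```
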
@@ -125,8 +125,8 @@ def __init__(
 
     def _make_pair_wise_relative_positions(self):
         # get relative_coords_table
-        relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0]).to(torch.float32)
-        relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1]).to(torch.float32)
+        relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], device="cpu").to(torch.float32)
+        relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], device="cpu").to(torch.float32)
         relative_coords_table = torch.stack(ndgrid(relative_coords_h, relative_coords_w))
         relative_coords_table = relative_coords_table.permute(1, 2, 0).contiguous().unsqueeze(0)  # 1, 2*Wh-1, 2*Ww-1, 2
         if self.pretrained_window_size[0] > 0:
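
The `relative_coords_table` assembled above enumerates every possible (dh, dw) offset between two tokens of a Wh x Ww window; SwinV2 then normalizes and log-scales this table before feeding it to a small MLP that produces the continuous relative position bias. A toy version for a 3x3 window, using `torch.meshgrid(..., indexing='ij')` as a stand-in for timm's `ndgrid` helper (assumed equivalent here):

```python
import torch

# Toy window: Wh = Ww = 3. Offsets between any two tokens in the window span
# [-(W - 1), W - 1] on each axis, i.e. 2*W - 1 distinct values per axis.
Wh, Ww = 3, 3
rel_h = torch.arange(-(Wh - 1), Wh).to(torch.float32)   # [-2, -1, 0, 1, 2]
rel_w = torch.arange(-(Ww - 1), Ww).to(torch.float32)

# Stand-in for timm's ndgrid(): meshgrid with 'ij' indexing.
table = torch.stack(torch.meshgrid(rel_h, rel_w, indexing='ij'))
table = table.permute(1, 2, 0).contiguous().unsqueeze(0)  # 1, 2*Wh-1, 2*Ww-1, 2
print(table.shape)        # torch.Size([1, 5, 5, 2])
print(table[0, 0, 0])     # tensor([-2., -2.]) -> offset (dh, dw) = (-2, -2)
```
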
@@ -141,8 +141,8 @@ def _make_pair_wise_relative_positions(self):
         self.register_buffer("relative_coords_table", relative_coords_table, persistent=False)
 
         # get pair-wise relative position index for each token inside the window
-        coords_h = torch.arange(self.window_size[0])
-        coords_w = torch.arange(self.window_size[1])
+        coords_h = torch.arange(self.window_size[0], device="cpu")
+        coords_w = torch.arange(self.window_size[1], device="cpu")
         coords = torch.stack(ndgrid(coords_h, coords_w))  # 2, Wh, Ww
         coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
         relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
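
`relative_position_index`, computed from these coordinates, maps every ordered token pair (i, j) in the window to one row of that table: the 2-D offset is shifted to be non-negative and flattened into a single index in [0, (2*Wh - 1) * (2*Ww - 1)). A small worked example for a 2x2 window, following the standard Swin-style computation (shown standalone rather than as a method, for clarity):

```python
import torch

Wh, Ww = 2, 2
coords = torch.stack(torch.meshgrid(
    torch.arange(Wh), torch.arange(Ww), indexing='ij'))            # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1)                          # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, N, N
relative_coords = relative_coords.permute(1, 2, 0).contiguous()    # N, N, 2

# Shift offsets to start at 0, then fold (dh, dw) into one flat index into
# the (2*Wh - 1) * (2*Ww - 1) entries of the bias table.
relative_coords[:, :, 0] += Wh - 1
relative_coords[:, :, 1] += Ww - 1
relative_coords[:, :, 0] *= 2 * Ww - 1
relative_position_index = relative_coords.sum(-1)                  # N, N
print(relative_position_index)   # 4x4 matrix of indices into the 9-row table
```

Token pairs with the same spatial offset share an index, so they also share a bias value.
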
@@ -293,7 +293,7 @@ def get_attn_mask(self, x: Optional[torch.Tensor] = None) -> Optional[torch.Tens
         if any(self.shift_size):
             # calculate attention mask for SW-MSA
             if x is None:
-                img_mask = torch.zeros((1, *self.input_resolution, 1))  # 1 H W 1
+                img_mask = torch.zeros((1, *self.input_resolution, 1), device="cpu")  # 1 H W 1
             else:
                 img_mask = torch.zeros((1, x.shape[1], x.shape[2], 1), dtype=x.dtype, device=x.device)  # 1 H W 1
             cnt = 0
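
`get_attn_mask` supports shifted-window attention (SW-MSA): after the cyclic shift, the feature map is divided into bands along each axis, every band intersection receives a distinct region id (`cnt`), and attention between tokens whose region ids differ is later suppressed with a large negative additive bias. A toy illustration of the region labelling on a 4x4 map with 2x2 windows and a shift of 1, with the loop structure modelled on the usual Swin implementation (sizes here are hypothetical):

```python
import torch

H = W = 4
window_size = (2, 2)
shift_size = (1, 1)

img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
cnt = 0
# Three bands per axis: the untouched part, the wrapped window remainder,
# and the shifted strip. Tokens originating from different bands must not
# attend to each other after the cyclic shift.
for h in (slice(0, -window_size[0]),
          slice(-window_size[0], -shift_size[0]),
          slice(-shift_size[0], None)):
    for w in (slice(0, -window_size[1]),
              slice(-window_size[1], -shift_size[1]),
              slice(-shift_size[1], None)):
        img_mask[:, h, w, :] = cnt
        cnt += 1
print(img_mask[0, :, :, 0])  # region ids 0..8 laid out over the 4x4 map
```

The mask is then window-partitioned, and pairwise differences of region ids are converted into zero vs. large-negative additive attention biases.
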