1010import torch .nn as nn
1111import torch .nn .functional as F
1212
13+ from .grid import ndgrid
1314from .interpolate import RegularGridInterpolator
1415from .mlp import Mlp
1516from .weight_init import trunc_normal_
@@ -26,12 +27,7 @@ def gen_relative_position_index(
2627 # get pair-wise relative position index for each token inside the window
2728 assert k_size is None , 'Different q & k sizes not currently supported' # FIXME
2829
29- coords = torch .stack (
30- torch .meshgrid ([
31- torch .arange (q_size [0 ]),
32- torch .arange (q_size [1 ])
33- ])
34- ).flatten (1 ) # 2, Wh, Ww
30+ coords = torch .stack (ndgrid (torch .arange (q_size [0 ]), torch .arange (q_size [1 ]))).flatten (1 ) # 2, Wh, Ww
3531 relative_coords = coords [:, :, None ] - coords [:, None , :] # 2, Wh*Ww, Wh*Ww
3632 relative_coords = relative_coords .permute (1 , 2 , 0 ) # Qh*Qw, Kh*Kw, 2
3733 relative_coords [:, :, 0 ] += q_size [0 ] - 1 # shift to start from 0
@@ -42,16 +38,16 @@ def gen_relative_position_index(
4238 # else:
4339 # # FIXME different q vs k sizes is a WIP, need to better offset the two grids?
4440 # q_coords = torch.stack(
45- # torch.meshgrid([
41+ # ndgrid(
4642 # torch.arange(q_size[0]),
4743 # torch.arange(q_size[1])
48- # ] )
44+ # )
4945 # ).flatten(1) # 2, Wh, Ww
5046 # k_coords = torch.stack(
51- # torch.meshgrid([
47+ # ndgrid(
5248 # torch.arange(k_size[0]),
5349 # torch.arange(k_size[1])
54- # ] )
50+ # )
5551 # ).flatten(1)
5652 # relative_coords = q_coords[:, :, None] - k_coords[:, None, :] # 2, Wh*Ww, Wh*Ww
5753 # relative_coords = relative_coords.permute(1, 2, 0) # Qh*Qw, Kh*Kw, 2
@@ -232,7 +228,7 @@ def _calc(src, dst):
232228 tx = dst_size [1 ] // 2.0
233229 dy = torch .arange (- ty , ty + 0.1 , 1.0 )
234230 dx = torch .arange (- tx , tx + 0.1 , 1.0 )
235- dyx = torch . meshgrid ([ dy , dx ] )
231+ dyx = ndgrid ( dy , dx )
236232 # print("Target positions = %s" % str(dx))
237233
238234 all_rel_pos_bias = []
@@ -313,7 +309,7 @@ def gen_relative_log_coords(
313309 # as per official swin-v2 impl, supporting timm specific 'cr' log coords as well
314310 relative_coords_h = torch .arange (- (win_size [0 ] - 1 ), win_size [0 ]).to (torch .float32 )
315311 relative_coords_w = torch .arange (- (win_size [1 ] - 1 ), win_size [1 ]).to (torch .float32 )
316- relative_coords_table = torch .stack (torch . meshgrid ([ relative_coords_h , relative_coords_w ] ))
312+ relative_coords_table = torch .stack (ndgrid ( relative_coords_h , relative_coords_w ))
317313 relative_coords_table = relative_coords_table .permute (1 , 2 , 0 ).contiguous () # 2*Wh-1, 2*Ww-1, 2
318314 if mode == 'swin' :
319315 if pretrained_win_size [0 ] > 0 :
0 commit comments