Skip to content

Commit 88889de

Browse files
committed
Fix meshgrid deprecation warnings and backward compat with explicit 'ndgrid' and 'meshgrid' fn w/o indexing arg
1 parent fa247fd commit 88889de

File tree

13 files changed

+110
-51
lines changed

13 files changed

+110
-51
lines changed

timm/layers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from .format import Format, get_channel_dim, get_spatial_dim, nchw_to, nhwc_to
2525
from .gather_excite import GatherExcite
2626
from .global_context import GlobalContext
27+
from .grid import ndgrid, meshgrid
2728
from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible, extend_tuple
2829
from .inplace_abn import InplaceAbn
2930
from .linear import Linear

timm/layers/drop.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,18 @@
1818
import torch.nn as nn
1919
import torch.nn.functional as F
2020

21+
from .grid import ndgrid
22+
2123

2224
def drop_block_2d(
23-
x, drop_prob: float = 0.1, block_size: int = 7, gamma_scale: float = 1.0,
24-
with_noise: bool = False, inplace: bool = False, batchwise: bool = False):
25+
x,
26+
drop_prob: float = 0.1,
27+
block_size: int = 7,
28+
gamma_scale: float = 1.0,
29+
with_noise: bool = False,
30+
inplace: bool = False,
31+
batchwise: bool = False
32+
):
2533
""" DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
2634
2735
DropBlock with an experimental gaussian noise option. This layer has been tested on a few training
@@ -35,7 +43,7 @@ def drop_block_2d(
3543
(W - block_size + 1) * (H - block_size + 1))
3644

3745
# Forces the block to be inside the feature map.
38-
w_i, h_i = torch.meshgrid(torch.arange(W).to(x.device), torch.arange(H).to(x.device))
46+
w_i, h_i = ndgrid(torch.arange(W, device=x.device), torch.arange(H, device=x.device))
3947
valid_block = ((w_i >= clipped_block_size // 2) & (w_i < W - (clipped_block_size - 1) // 2)) & \
4048
((h_i >= clipped_block_size // 2) & (h_i < H - (clipped_block_size - 1) // 2))
4149
valid_block = torch.reshape(valid_block, (1, 1, H, W)).to(dtype=x.dtype)
@@ -68,8 +76,13 @@ def drop_block_2d(
6876

6977

7078
def drop_block_fast_2d(
71-
x: torch.Tensor, drop_prob: float = 0.1, block_size: int = 7,
72-
gamma_scale: float = 1.0, with_noise: bool = False, inplace: bool = False):
79+
x: torch.Tensor,
80+
drop_prob: float = 0.1,
81+
block_size: int = 7,
82+
gamma_scale: float = 1.0,
83+
with_noise: bool = False,
84+
inplace: bool = False,
85+
):
7386
""" DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
7487
7588
DropBlock with an experimental gaussian noise option. Simplified from above without concern for valid

timm/layers/grid.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
from typing import Tuple
2+
3+
import torch
4+
5+
6+
def ndgrid(*tensors) -> Tuple[torch.Tensor, ...]:
    """Generate an N-D grid in dimension ('ij') order.

    The ndgrid function is like meshgrid except that the order of the first two input
    arguments is switched. That is, the statement

        [X1, X2, X3] = ndgrid(x1, x2, x3)

    produces the same result as

        [X2, X1, X3] = meshgrid(x2, x1, x3)

    This naming is based on MATLAB. Its purpose is to avoid confusion caused by torch moving
    torch.meshgrid behaviour from matching ndgrid ('ij') indexing towards the numpy meshgrid
    default of 'xy'.

    Args:
        *tensors: the input tensors, given either as separate positional arguments
            (``ndgrid(a, b)``) or as a single list / tuple (``ndgrid([a, b])``).

    Returns:
        The grid tensors produced by ``torch.meshgrid`` with 'ij' indexing, one per input.
    """
    # Accept a lone list / tuple explicitly so both calling conventions are part of this
    # wrapper's contract rather than a side effect of torch.meshgrid's argument handling.
    if len(tensors) == 1 and isinstance(tensors[0], (list, tuple)):
        tensors = tuple(tensors[0])

    try:
        return torch.meshgrid(*tensors, indexing='ij')
    except TypeError:
        # old PyTorch < 1.10 will follow this path as it does not have indexing arg,
        # the old behaviour of meshgrid was 'ij'
        return torch.meshgrid(*tensors)
28+
29+
30+
def meshgrid(*tensors) -> Tuple[torch.Tensor, ...]:
    """Generate an N-D grid in spatial dim ('xy') order.

    The meshgrid function is similar to ndgrid except that the order of the
    first two input and output arguments is switched. That is, the statement

        [X, Y, Z] = meshgrid(x, y, z)

    produces the same result as

        [Y, X, Z] = ndgrid(y, x, z)

    Because of this, meshgrid is better suited to problems in two- or three-dimensional
    Cartesian space, while ndgrid is better suited to multidimensional problems that aren't
    spatially based.

    Args:
        *tensors: the input tensors, given either as separate positional arguments
            (``meshgrid(x, y)``) or as a single list / tuple (``meshgrid([x, y])``).

    Returns:
        The grid tensors produced by ``torch.meshgrid`` with 'xy' indexing, one per input.

    Raises:
        TypeError: on PyTorch < 1.10, where torch.meshgrid has no ``indexing`` argument.
    """
    # Accept a lone list / tuple explicitly, mirroring the calling conventions of ndgrid.
    if len(tensors) == 1 and isinstance(tensors[0], (list, tuple)):
        tensors = tuple(tensors[0])

    # NOTE: this will throw in PyTorch < 1.10 as meshgrid did not support indexing arg or have
    # capability of generating grid in xy order before then.
    return torch.meshgrid(*tensors, indexing='xy')
49+

timm/layers/lambda_layer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,14 @@
2424
from torch import nn
2525
import torch.nn.functional as F
2626

27+
from .grid import ndgrid
2728
from .helpers import to_2tuple, make_divisible
2829
from .weight_init import trunc_normal_
2930

3031

3132
def rel_pos_indices(size):
3233
size = to_2tuple(size)
33-
pos = torch.stack(torch.meshgrid(torch.arange(size[0]), torch.arange(size[1]))).flatten(1)
34+
pos = torch.stack(ndgrid(torch.arange(size[0]), torch.arange(size[1]))).flatten(1)
3435
rel_pos = pos[:, None, :] - pos[:, :, None]
3536
rel_pos[0] += size[0] - 1
3637
rel_pos[1] += size[1] - 1

timm/layers/pos_embed_rel.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import torch.nn as nn
1111
import torch.nn.functional as F
1212

13+
from .grid import ndgrid
1314
from .interpolate import RegularGridInterpolator
1415
from .mlp import Mlp
1516
from .weight_init import trunc_normal_
@@ -26,12 +27,7 @@ def gen_relative_position_index(
2627
# get pair-wise relative position index for each token inside the window
2728
assert k_size is None, 'Different q & k sizes not currently supported' # FIXME
2829

29-
coords = torch.stack(
30-
torch.meshgrid([
31-
torch.arange(q_size[0]),
32-
torch.arange(q_size[1])
33-
])
34-
).flatten(1) # 2, Wh, Ww
30+
coords = torch.stack(ndgrid(torch.arange(q_size[0]), torch.arange(q_size[1]))).flatten(1) # 2, Wh, Ww
3531
relative_coords = coords[:, :, None] - coords[:, None, :] # 2, Wh*Ww, Wh*Ww
3632
relative_coords = relative_coords.permute(1, 2, 0) # Qh*Qw, Kh*Kw, 2
3733
relative_coords[:, :, 0] += q_size[0] - 1 # shift to start from 0
@@ -42,16 +38,16 @@ def gen_relative_position_index(
4238
# else:
4339
# # FIXME different q vs k sizes is a WIP, need to better offset the two grids?
4440
# q_coords = torch.stack(
45-
# torch.meshgrid([
41+
# ndgrid(
4642
# torch.arange(q_size[0]),
4743
# torch.arange(q_size[1])
48-
# ])
44+
# )
4945
# ).flatten(1) # 2, Wh, Ww
5046
# k_coords = torch.stack(
51-
# torch.meshgrid([
47+
# ndgrid(
5248
# torch.arange(k_size[0]),
5349
# torch.arange(k_size[1])
54-
# ])
50+
# )
5551
# ).flatten(1)
5652
# relative_coords = q_coords[:, :, None] - k_coords[:, None, :] # 2, Wh*Ww, Wh*Ww
5753
# relative_coords = relative_coords.permute(1, 2, 0) # Qh*Qw, Kh*Kw, 2
@@ -232,7 +228,7 @@ def _calc(src, dst):
232228
tx = dst_size[1] // 2.0
233229
dy = torch.arange(-ty, ty + 0.1, 1.0)
234230
dx = torch.arange(-tx, tx + 0.1, 1.0)
235-
dyx = torch.meshgrid([dy, dx])
231+
dyx = ndgrid(dy, dx)
236232
# print("Target positions = %s" % str(dx))
237233

238234
all_rel_pos_bias = []
@@ -313,7 +309,7 @@ def gen_relative_log_coords(
313309
# as per official swin-v2 impl, supporting timm specific 'cr' log coords as well
314310
relative_coords_h = torch.arange(-(win_size[0] - 1), win_size[0]).to(torch.float32)
315311
relative_coords_w = torch.arange(-(win_size[1] - 1), win_size[1]).to(torch.float32)
316-
relative_coords_table = torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w]))
312+
relative_coords_table = torch.stack(ndgrid(relative_coords_h, relative_coords_w))
317313
relative_coords_table = relative_coords_table.permute(1, 2, 0).contiguous() # 2*Wh-1, 2*Ww-1, 2
318314
if mode == 'swin':
319315
if pretrained_win_size[0] > 0:

timm/layers/pos_embed_sincos.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import torch
99
from torch import nn as nn
1010

11+
from .grid import ndgrid
1112
from .trace_utils import _assert
1213

1314

@@ -64,10 +65,10 @@ def build_sincos2d_pos_embed(
6465

6566
if reverse_coord:
6667
feat_shape = feat_shape[::-1] # stack W, H instead of H, W
67-
grid = torch.stack(torch.meshgrid(
68-
[torch.arange(s, device=device, dtype=torch.int64).to(torch.float32)
69-
for s in feat_shape])
70-
).flatten(1).transpose(0, 1)
68+
grid = torch.stack(ndgrid([
69+
torch.arange(s, device=device, dtype=torch.int64).to(torch.float32)
70+
for s in feat_shape
71+
])).flatten(1).transpose(0, 1)
7172
pos2 = grid.unsqueeze(-1) * bands.unsqueeze(0)
7273
# FIXME add support for unflattened spatial dim?
7374

@@ -137,7 +138,7 @@ def build_fourier_pos_embed(
137138
# eva's scheme for resizing rope embeddings (ref shape = pretrain)
138139
t = [x / f * r for x, f, r in zip(t, feat_shape, ref_feat_shape)]
139140

140-
grid = torch.stack(torch.meshgrid(t), dim=-1)
141+
grid = torch.stack(ndgrid(t), dim=-1)
141142
grid = grid.unsqueeze(-1)
142143
pos = grid * bands
143144

timm/models/beit.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848

4949
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
5050
from timm.layers import PatchEmbed, Mlp, SwiGLU, LayerNorm, DropPath, trunc_normal_, use_fused_attn
51-
from timm.layers import resample_patch_embed, resample_abs_pos_embed, resize_rel_pos_bias_table
51+
from timm.layers import resample_patch_embed, resample_abs_pos_embed, resize_rel_pos_bias_table, ndgrid
5252

5353

5454
from ._builder import build_model_with_cfg
@@ -63,9 +63,7 @@ def gen_relative_position_index(window_size: Tuple[int, int]) -> torch.Tensor:
6363
# cls to token & token 2 cls & cls to cls
6464
# get pair-wise relative position index for each token inside the window
6565
window_area = window_size[0] * window_size[1]
66-
coords = torch.stack(torch.meshgrid(
67-
[torch.arange(window_size[0]),
68-
torch.arange(window_size[1])])) # 2, Wh, Ww
66+
coords = torch.stack(ndgrid(torch.arange(window_size[0]), torch.arange(window_size[1]))) # 2, Wh, Ww
6967
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
7068
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
7169
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2

timm/models/efficientformer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import torch.nn as nn
1919

2020
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
21-
from timm.layers import DropPath, trunc_normal_, to_2tuple, Mlp
21+
from timm.layers import DropPath, trunc_normal_, to_2tuple, Mlp, ndgrid
2222
from ._builder import build_model_with_cfg
2323
from ._manipulate import checkpoint_seq
2424
from ._registry import generate_default_cfgs, register_model
@@ -63,7 +63,7 @@ def __init__(
6363
self.proj = nn.Linear(self.val_attn_dim, dim)
6464

6565
resolution = to_2tuple(resolution)
66-
pos = torch.stack(torch.meshgrid(torch.arange(resolution[0]), torch.arange(resolution[1]))).flatten(1)
66+
pos = torch.stack(ndgrid(torch.arange(resolution[0]), torch.arange(resolution[1]))).flatten(1)
6767
rel_pos = (pos[..., :, None] - pos[..., None, :]).abs()
6868
rel_pos = (rel_pos[0] * resolution[1]) + rel_pos[1]
6969
self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, resolution[0] * resolution[1]))

timm/models/efficientformer_v2.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
2525
from timm.layers import create_conv2d, create_norm_layer, get_act_layer, get_norm_layer, ConvNormAct
26-
from timm.layers import DropPath, trunc_normal_, to_2tuple, to_ntuple
26+
from timm.layers import DropPath, trunc_normal_, to_2tuple, to_ntuple, ndgrid
2727
from ._builder import build_model_with_cfg
2828
from ._manipulate import checkpoint_seq
2929
from ._registry import generate_default_cfgs, register_model
@@ -129,7 +129,7 @@ def __init__(
129129
self.act = act_layer()
130130
self.proj = ConvNorm(self.dh, dim, 1)
131131

132-
pos = torch.stack(torch.meshgrid(torch.arange(self.resolution[0]), torch.arange(self.resolution[1]))).flatten(1)
132+
pos = torch.stack(ndgrid(torch.arange(self.resolution[0]), torch.arange(self.resolution[1]))).flatten(1)
133133
rel_pos = (pos[..., :, None] - pos[..., None, :]).abs()
134134
rel_pos = (rel_pos[0] * self.resolution[1]) + rel_pos[1]
135135
self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, self.N))
@@ -231,12 +231,11 @@ def __init__(
231231
self.proj = ConvNorm(self.dh, self.out_dim, 1)
232232

233233
self.attention_biases = nn.Parameter(torch.zeros(num_heads, self.N))
234-
k_pos = torch.stack(torch.meshgrid(torch.arange(
235-
self.resolution[0]),
236-
torch.arange(self.resolution[1]))).flatten(1)
237-
q_pos = torch.stack(torch.meshgrid(
234+
k_pos = torch.stack(ndgrid(torch.arange(self.resolution[0]), torch.arange(self.resolution[1]))).flatten(1)
235+
q_pos = torch.stack(ndgrid(
238236
torch.arange(0, self.resolution[0], step=2),
239-
torch.arange(0, self.resolution[1], step=2))).flatten(1)
237+
torch.arange(0, self.resolution[1], step=2)
238+
)).flatten(1)
240239
rel_pos = (q_pos[..., :, None] - k_pos[..., None, :]).abs()
241240
rel_pos = (rel_pos[0] * self.resolution[1]) + rel_pos[1]
242241
self.register_buffer('attention_bias_idxs', rel_pos, persistent=False)

timm/models/levit.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
import torch.nn as nn
3232

3333
from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN
34-
from timm.layers import to_ntuple, to_2tuple, get_act_layer, DropPath, trunc_normal_
34+
from timm.layers import to_ntuple, to_2tuple, get_act_layer, DropPath, trunc_normal_, ndgrid
3535
from ._builder import build_model_with_cfg
3636
from ._manipulate import checkpoint_seq
3737
from ._registry import generate_default_cfgs, register_model
@@ -194,7 +194,7 @@ def __init__(
194194
]))
195195

196196
self.attention_biases = nn.Parameter(torch.zeros(num_heads, resolution[0] * resolution[1]))
197-
pos = torch.stack(torch.meshgrid(torch.arange(resolution[0]), torch.arange(resolution[1]))).flatten(1)
197+
pos = torch.stack(ndgrid(torch.arange(resolution[0]), torch.arange(resolution[1]))).flatten(1)
198198
rel_pos = (pos[..., :, None] - pos[..., None, :]).abs()
199199
rel_pos = (rel_pos[0] * resolution[1]) + rel_pos[1]
200200
self.register_buffer('attention_bias_idxs', rel_pos, persistent=False)
@@ -290,10 +290,11 @@ def __init__(
290290
]))
291291

292292
self.attention_biases = nn.Parameter(torch.zeros(num_heads, resolution[0] * resolution[1]))
293-
k_pos = torch.stack(torch.meshgrid(torch.arange(resolution[0]), torch.arange(resolution[1]))).flatten(1)
294-
q_pos = torch.stack(torch.meshgrid(
293+
k_pos = torch.stack(ndgrid(torch.arange(resolution[0]), torch.arange(resolution[1]))).flatten(1)
294+
q_pos = torch.stack(ndgrid(
295295
torch.arange(0, resolution[0], step=stride),
296-
torch.arange(0, resolution[1], step=stride))).flatten(1)
296+
torch.arange(0, resolution[1], step=stride)
297+
)).flatten(1)
297298
rel_pos = (q_pos[..., :, None] - k_pos[..., None, :]).abs()
298299
rel_pos = (rel_pos[0] * resolution[1]) + rel_pos[1]
299300
self.register_buffer('attention_bias_idxs', rel_pos, persistent=False)

0 commit comments

Comments
 (0)