Skip to content

Commit 715519a

Browse files
committed
Rethink name of patch embed grid info
1 parent b2c305c commit 715519a

File tree

3 files changed

+7
-7
lines changed

3 files changed

+7
-7
lines changed

timm/models/coat.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -490,7 +490,7 @@ def forward_features(self, x0):
490490

491491
# Serial blocks 1.
492492
x1 = self.patch_embed1(x0)
493-
H1, W1 = self.patch_embed1.out_size
493+
H1, W1 = self.patch_embed1.grid_size
494494
x1 = self.insert_cls(x1, self.cls_token1)
495495
for blk in self.serial_blocks1:
496496
x1 = blk(x1, size=(H1, W1))
@@ -499,7 +499,7 @@ def forward_features(self, x0):
499499

500500
# Serial blocks 2.
501501
x2 = self.patch_embed2(x1_nocls)
502-
H2, W2 = self.patch_embed2.out_size
502+
H2, W2 = self.patch_embed2.grid_size
503503
x2 = self.insert_cls(x2, self.cls_token2)
504504
for blk in self.serial_blocks2:
505505
x2 = blk(x2, size=(H2, W2))
@@ -508,7 +508,7 @@ def forward_features(self, x0):
508508

509509
# Serial blocks 3.
510510
x3 = self.patch_embed3(x2_nocls)
511-
H3, W3 = self.patch_embed3.out_size
511+
H3, W3 = self.patch_embed3.grid_size
512512
x3 = self.insert_cls(x3, self.cls_token3)
513513
for blk in self.serial_blocks3:
514514
x3 = blk(x3, size=(H3, W3))
@@ -517,7 +517,7 @@ def forward_features(self, x0):
517517

518518
# Serial blocks 4.
519519
x4 = self.patch_embed4(x3_nocls)
520-
H4, W4 = self.patch_embed4.out_size
520+
H4, W4 = self.patch_embed4.grid_size
521521
x4 = self.insert_cls(x4, self.cls_token4)
522522
for blk in self.serial_blocks4:
523523
x4 = blk(x4, size=(H4, W4))

timm/models/layers/patch_embed.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_
2121
patch_size = to_2tuple(patch_size)
2222
self.img_size = img_size
2323
self.patch_size = patch_size
24-
self.out_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
25-
self.num_patches = self.out_size[0] * self.out_size[1]
24+
self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
25+
self.num_patches = self.grid_size[0] * self.grid_size[1]
2626

2727
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
2828
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

timm/models/swin_transformer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -467,7 +467,7 @@ def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
467467
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
468468
norm_layer=norm_layer if self.patch_norm else None)
469469
num_patches = self.patch_embed.num_patches
470-
self.patch_grid = self.patch_embed.out_size
470+
self.patch_grid = self.patch_embed.grid_size
471471

472472
# absolute position embedding
473473
if self.ape:

0 commit comments

Comments
 (0)