Skip to content

Commit 1b50b15

Browse files
authored
Merge pull request #2092 from huggingface/mesa_ema
ModelEMAV3 + MESA experiments
2 parents 88889de + 47c9bc4 commit 1b50b15

File tree

13 files changed

+1065
-119
lines changed

13 files changed

+1065
-119
lines changed

timm/data/transforms.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,16 +32,12 @@ def __call__(self, pil_img):
3232

3333

3434
class ToTensor:
35-
35+
""" ToTensor with no rescaling of values"""
3636
def __init__(self, dtype=torch.float32):
3737
self.dtype = dtype
3838

3939
def __call__(self, pil_img):
40-
np_img = np.array(pil_img, dtype=np.uint8)
41-
if np_img.ndim < 3:
42-
np_img = np.expand_dims(np_img, axis=-1)
43-
np_img = np.rollaxis(np_img, 2) # HWC to CHW
44-
return torch.from_numpy(np_img).to(dtype=self.dtype)
40+
return F.pil_to_tensor(pil_img).to(dtype=self.dtype)
4541

4642

4743
# Pillow is deprecating the top-level resampling attributes (e.g., Image.BILINEAR) in

timm/layers/classifier.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -180,10 +180,10 @@ def __init__(
180180
self.drop = nn.Dropout(drop_rate)
181181
self.fc = linear_layer(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
182182

183-
def reset(self, num_classes, global_pool=None):
184-
if global_pool is not None:
185-
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
186-
self.flatten = nn.Flatten(1) if global_pool else nn.Identity()
183+
def reset(self, num_classes, pool_type=None):
184+
if pool_type is not None:
185+
self.global_pool = SelectAdaptivePool2d(pool_type=pool_type)
186+
self.flatten = nn.Flatten(1) if pool_type else nn.Identity()
187187
self.use_conv = self.global_pool.is_identity()
188188
linear_layer = partial(nn.Conv2d, kernel_size=1) if self.use_conv else nn.Linear
189189
if self.hidden_size:

timm/layers/create_act.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ def get_act_layer(name: Union[Type[nn.Module], str] = 'relu'):
148148
return _ACT_LAYER_DEFAULT[name]
149149

150150

151-
def create_act_layer(name: Union[nn.Module, str], inplace=None, **kwargs):
151+
def create_act_layer(name: Union[Type[nn.Module], str], inplace=None, **kwargs):
152152
act_layer = get_act_layer(name)
153153
if act_layer is None:
154154
return None

timm/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
from .mvitv2 import *
4040
from .nasnet import *
4141
from .nest import *
42+
from .nextvit import *
4243
from .nfnet import *
4344
from .pit import *
4445
from .pnasnet import *

timm/models/davit.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -547,6 +547,17 @@ def _init_weights(self, m):
547547
if isinstance(m, nn.Linear) and m.bias is not None:
548548
nn.init.constant_(m.bias, 0)
549549

550+
@torch.jit.ignore
551+
def group_matcher(self, coarse=False):
552+
return dict(
553+
stem=r'^stem', # stem and embed
554+
blocks=r'^stages\.(\d+)' if coarse else [
555+
(r'^stages\.(\d+).downsample', (0,)),
556+
(r'^stages\.(\d+)\.blocks\.(\d+)', None),
557+
(r'^norm_pre', (99999,)),
558+
]
559+
)
560+
550561
@torch.jit.ignore
551562
def set_grad_checkpointing(self, enable=True):
552563
self.grad_checkpointing = enable
@@ -558,7 +569,7 @@ def get_classifier(self):
558569
return self.head.fc
559570

560571
def reset_classifier(self, num_classes, global_pool=None):
561-
self.head.reset(num_classes, global_pool=global_pool)
572+
self.head.reset(num_classes, global_pool)
562573

563574
def forward_features(self, x):
564575
x = self.stem(x)

0 commit comments

Comments
 (0)