Commit cb98094

Merge branch 'main' into fast_model

2 parents 89d2952 + 81900a6

31 files changed: +1574, -61 lines

README.md
Lines changed: 1 addition & 1 deletion

@@ -566,7 +566,7 @@ Model validation results can be found in the [results tables](results/README.md)
 
 The official documentation can be found at https://huggingface.co/docs/hub/timm. Documentation contributions are welcome.
 
-[Getting Started with PyTorch Image Models (timm): A Practitioner’s Guide](https://towardsdatascience.com/getting-started-with-pytorch-image-models-timm-a-practitioners-guide-4e77b4bf9055) by [Chris Hughes](https://github.com/Chris-hughes10) is an extensive blog post covering many aspects of `timm` in detail.
+[Getting Started with PyTorch Image Models (timm): A Practitioner’s Guide](https://towardsdatascience.com/getting-started-with-pytorch-image-models-timm-a-practitioners-guide-4e77b4bf9055-2/) by [Chris Hughes](https://github.com/Chris-hughes10) is an extensive blog post covering many aspects of `timm` in detail.
 
 [timmdocs](http://timm.fast.ai/) is an alternate set of documentation for `timm`. A big thanks to [Aman Arora](https://github.com/amaarora) for his efforts creating timmdocs.

tests/test_models.py
Lines changed: 9 additions & 3 deletions

@@ -53,8 +53,10 @@
     'vision_transformer', 'vision_transformer_sam', 'vision_transformer_hybrid', 'vision_transformer_relpos',
     'beit', 'mvitv2', 'eva', 'cait', 'xcit', 'volo', 'twins', 'deit', 'swin_transformer', 'swin_transformer_v2',
     'swin_transformer_v2_cr', 'maxxvit', 'efficientnet', 'mobilenetv3', 'levit', 'efficientformer', 'resnet',
-    'regnet', 'byobnet', 'byoanet', 'mlp_mixer', 'hiera', 'fastvit', 'hieradet_sam2', 'aimv2*', 'swiftformer',
-    'starnet', 'shvit', 'fasternet',
+    'regnet', 'byobnet', 'byoanet', 'mlp_mixer', 'hiera', 'fastvit', 'hieradet_sam2', 'aimv2*',
+    'tiny_vit', 'vovnet', 'tresnet', 'rexnet', 'resnetv2', 'repghost', 'repvit', 'pvt_v2', 'nextvit', 'nest',
+    'mambaout', 'inception_next', 'inception_v4', 'hgnet', 'gcvit', 'focalnet', 'efficientformer_v2', 'edgenext',
+    'davit', 'rdnet', 'convnext', 'pit', 'starnet', 'shvit', 'fasternet', 'swiftformer',
 ]
 
 # transformer / hybrid models don't support full set of spatial / feature APIs and/or have spatial output.

@@ -510,8 +512,9 @@ def test_model_forward_intermediates(model_name, batch_size):
     spatial_axis = get_spatial_dim(output_fmt)
     import math
 
+    inpt = torch.randn((batch_size, *input_size))
     output, intermediates = model.forward_intermediates(
-        torch.randn((batch_size, *input_size)),
+        inpt,
         output_fmt=output_fmt,
     )
     assert len(expected_channels) == len(intermediates)

@@ -523,6 +526,9 @@ def test_model_forward_intermediates(model_name, batch_size):
         assert o.shape[0] == batch_size
         assert not torch.isnan(o).any()
 
+    output2 = model.forward_features(inpt)
+    assert torch.allclose(output, output2)
+
 
 def _create_fx_model(model, train=False):
     # This block of code does a bit of juggling to handle any case where there are multiple outputs in train mode
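The new assertions encode an invariant worth relying on as a user: when forward_intermediates() runs all stages (no stop_early, no intermediates_only), its final output should match forward_features() exactly. A minimal sketch of the same check outside the test harness, assuming any timm model implementing the API (convnext_tiny is just an example):

import torch
import timm

# Any model with forward_intermediates() works; convnext_tiny is illustrative.
model = timm.create_model('convnext_tiny', pretrained=False)
model.eval()  # disable dropout / drop-path so both paths are deterministic

x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    final, intermediates = model.forward_intermediates(x)  # default: run all stages
    features = model.forward_features(x)

# Final feature map from both code paths should be identical.
assert torch.allclose(final, features)
print([t.shape for t in intermediates])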

timm/data/dataset_factory.py
Lines changed: 1 addition & 0 deletions

@@ -144,6 +144,7 @@ def create_dataset(
             use_train = split in _TRAIN_SYNONYM
             ds = QMNIST(train=use_train, **torch_kwargs)
         elif name == 'imagenet':
+            torch_kwargs.pop('download')
             assert has_imagenet, 'Please update to a newer PyTorch and torchvision for ImageNet dataset.'
             if split in _EVAL_SYNONYM:
                 split = 'val'
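For context, torchvision's ImageNet class has no download argument (the data must already exist on disk), so the factory now strips that key before forwarding kwargs. A hedged usage sketch, assuming an ImageNet tree already extracted under the given root (the path is hypothetical):

from timm.data import create_dataset

# The 'torch/' prefix routes to the torchvision backend; the factory now
# drops the download flag, which torchvision's ImageNet does not accept.
ds = create_dataset(
    'torch/imagenet',
    root='/data/imagenet',   # hypothetical path to an extracted ImageNet tree
    split='validation',      # mapped internally to torchvision's 'val'
)
print(len(ds))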

timm/models/convnext.py
Lines changed: 12 additions & 12 deletions

@@ -452,29 +452,29 @@ def forward_intermediates(
         """
         assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
         intermediates = []
-        take_indices, max_index = feature_take_indices(len(self.stages) + 1, indices)
+        take_indices, max_index = feature_take_indices(len(self.stages), indices)
 
         # forward pass
-        feat_idx = 0  # stem is index 0
         x = self.stem(x)
-        if feat_idx in take_indices:
-            intermediates.append(x)
 
+        last_idx = len(self.stages) - 1
         if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
             stages = self.stages
         else:
-            stages = self.stages[:max_index]
-        for stage in stages:
-            feat_idx += 1
+            stages = self.stages[:max_index + 1]
+        for feat_idx, stage in enumerate(stages):
             x = stage(x)
             if feat_idx in take_indices:
-                # NOTE not bothering to apply norm_pre when norm=True as almost no models have it enabled
-                intermediates.append(x)
+                if norm and feat_idx == last_idx:
+                    intermediates.append(self.norm_pre(x))
+                else:
+                    intermediates.append(x)
 
         if intermediates_only:
             return intermediates
 
-        x = self.norm_pre(x)
+        if feat_idx == last_idx:
+            x = self.norm_pre(x)
 
         return x, intermediates

@@ -486,8 +486,8 @@ def prune_intermediate_layers(
     ):
         """ Prune layers not required for specified intermediates.
         """
-        take_indices, max_index = feature_take_indices(len(self.stages) + 1, indices)
-        self.stages = self.stages[:max_index]  # truncate blocks w/ stem as idx 0
+        take_indices, max_index = feature_take_indices(len(self.stages), indices)
+        self.stages = self.stages[:max_index + 1]  # truncate blocks w/ stem as idx 0
         if prune_norm:
             self.norm_pre = nn.Identity()
         if prune_head:
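The net effect is an indexing change: feature indices now count ConvNeXt stages only, so index 0 is the first stage rather than the stem, and feature_take_indices is sized to len(self.stages). A short sketch of selecting stages under the new scheme (the model name and the strides in the comment are illustrative):

import torch
import timm

model = timm.create_model('convnext_tiny', pretrained=False).eval()
x = torch.randn(1, 3, 224, 224)

# Indices now refer to stages: 0..3 for convnext_tiny; the stem output is
# no longer addressable as index 0.
feats = model.forward_intermediates(x, indices=[0, 3], intermediates_only=True)
for f in feats:
    print(f.shape)  # e.g. stage 0 at stride 4, stage 3 at stride 32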

timm/models/davit.py
Lines changed: 68 additions & 1 deletion

@@ -12,7 +12,7 @@
 # All rights reserved.
 # This source code is licensed under the MIT license
 from functools import partial
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn

@@ -23,6 +23,7 @@
 from timm.layers import DropPath, to_2tuple, trunc_normal_, Mlp, LayerNorm2d, get_norm_layer, use_fused_attn
 from timm.layers import NormMlpClassifierHead, ClassifierHead
 from ._builder import build_model_with_cfg
+from ._features import feature_take_indices
 from ._features_fx import register_notrace_function
 from ._manipulate import checkpoint_seq
 from ._registry import generate_default_cfgs, register_model

@@ -636,6 +637,72 @@ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         self.num_classes = num_classes
         self.head.reset(num_classes, global_pool)
 
+    def forward_intermediates(
+            self,
+            x: torch.Tensor,
+            indices: Optional[Union[int, List[int]]] = None,
+            norm: bool = False,
+            stop_early: bool = False,
+            output_fmt: str = 'NCHW',
+            intermediates_only: bool = False,
+    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
+        """ Forward features that returns intermediates.
+
+        Args:
+            x: Input image tensor
+            indices: Take last n blocks if int, all if None, select matching indices if sequence
+            norm: Apply norm layer to compatible intermediates
+            stop_early: Stop iterating over blocks when last desired intermediate hit
+            output_fmt: Shape of intermediate feature outputs
+            intermediates_only: Only return intermediate features
+        Returns:
+
+        """
+        assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
+        intermediates = []
+        take_indices, max_index = feature_take_indices(len(self.stages), indices)
+
+        # forward pass
+        x = self.stem(x)
+        last_idx = len(self.stages) - 1
+        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
+            stages = self.stages
+        else:
+            stages = self.stages[:max_index + 1]
+
+        for feat_idx, stage in enumerate(stages):
+            x = stage(x)
+            if feat_idx in take_indices:
+                if norm and feat_idx == last_idx:
+                    x_inter = self.norm_pre(x)  # applying final norm to last intermediate
+                else:
+                    x_inter = x
+                intermediates.append(x_inter)
+
+        if intermediates_only:
+            return intermediates
+
+        if feat_idx == last_idx:
+            x = self.norm_pre(x)
+
+        return x, intermediates
+
+    def prune_intermediate_layers(
+            self,
+            indices: Union[int, List[int]] = 1,
+            prune_norm: bool = False,
+            prune_head: bool = True,
+    ):
+        """ Prune layers not required for specified intermediates.
+        """
+        take_indices, max_index = feature_take_indices(len(self.stages), indices)
+        self.stages = self.stages[:max_index + 1]  # truncate blocks w/ stem as idx 0
+        if prune_norm:
+            self.norm_pre = nn.Identity()
+        if prune_head:
+            self.reset_classifier(0, '')
+        return take_indices
+
     def forward_features(self, x):
         x = self.stem(x)
         if self.grad_checkpointing and not torch.jit.is_scripting():
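With these methods added, DaViT can serve as a multi-scale feature backbone without involving its classifier head. A brief sketch, assuming the standard four-stage davit_tiny config:

import torch
import timm

model = timm.create_model('davit_tiny', pretrained=False).eval()
x = torch.randn(1, 3, 224, 224)

# Grab the last two stage outputs, e.g. for a feature pyramid.
# An int means "take the last n" per the docstring above.
feats = model.forward_intermediates(x, indices=2, intermediates_only=True)
print([f.shape for f in feats])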

timm/models/edgenext.py
Lines changed: 68 additions & 1 deletion

@@ -9,7 +9,7 @@
 """
 import math
 from functools import partial
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F

@@ -19,6 +19,7 @@
 from timm.layers import trunc_normal_tf_, DropPath, LayerNorm2d, Mlp, create_conv2d, \
     NormMlpClassifierHead, ClassifierHead
 from ._builder import build_model_with_cfg
+from ._features import feature_take_indices
 from ._features_fx import register_notrace_module
 from ._manipulate import named_apply, checkpoint_seq
 from ._registry import register_model, generate_default_cfgs

@@ -418,6 +419,72 @@ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         self.num_classes = num_classes
         self.head.reset(num_classes, global_pool)
 
+    def forward_intermediates(
+            self,
+            x: torch.Tensor,
+            indices: Optional[Union[int, List[int]]] = None,
+            norm: bool = False,
+            stop_early: bool = False,
+            output_fmt: str = 'NCHW',
+            intermediates_only: bool = False,
+    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
+        """ Forward features that returns intermediates.
+
+        Args:
+            x: Input image tensor
+            indices: Take last n blocks if int, all if None, select matching indices if sequence
+            norm: Apply norm layer to compatible intermediates
+            stop_early: Stop iterating over blocks when last desired intermediate hit
+            output_fmt: Shape of intermediate feature outputs
+            intermediates_only: Only return intermediate features
+        Returns:
+
+        """
+        assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
+        intermediates = []
+        take_indices, max_index = feature_take_indices(len(self.stages), indices)
+
+        # forward pass
+        x = self.stem(x)
+        last_idx = len(self.stages) - 1
+        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
+            stages = self.stages
+        else:
+            stages = self.stages[:max_index + 1]
+
+        for feat_idx, stage in enumerate(stages):
+            x = stage(x)
+            if feat_idx in take_indices:
+                if norm and feat_idx == last_idx:
+                    x_inter = self.norm_pre(x)  # applying final norm to last intermediate
+                else:
+                    x_inter = x
+                intermediates.append(x_inter)
+
+        if intermediates_only:
+            return intermediates
+
+        if feat_idx == last_idx:
+            x = self.norm_pre(x)
+
+        return x, intermediates
+
+    def prune_intermediate_layers(
+            self,
+            indices: Union[int, List[int]] = 1,
+            prune_norm: bool = False,
+            prune_head: bool = True,
+    ):
+        """ Prune layers not required for specified intermediates.
+        """
+        take_indices, max_index = feature_take_indices(len(self.stages), indices)
+        self.stages = self.stages[:max_index + 1]  # truncate blocks w/ stem as idx 0
+        if prune_norm:
+            self.norm_pre = nn.Identity()
+        if prune_head:
+            self.reset_classifier(0, '')
+        return take_indices
+
     def forward_features(self, x):
         x = self.stem(x)
         x = self.stages(x)
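prune_intermediate_layers complements the forward method by physically truncating the model when only early features are needed. A sketch of that workflow, with edgenext_small as an illustrative model name and an assumed 256x256 input:

import torch
import timm

model = timm.create_model('edgenext_small', pretrained=False).eval()

# Keep only what's needed for the first two stage outputs; drop head and norm.
take_indices = model.prune_intermediate_layers(
    indices=[0, 1], prune_norm=True, prune_head=True)

x = torch.randn(1, 3, 256, 256)
feats = model.forward_intermediates(x, indices=take_indices, intermediates_only=True)
print([f.shape for f in feats])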

timm/models/efficientformer_v2.py
Lines changed: 69 additions & 1 deletion

@@ -16,7 +16,7 @@
 """
 import math
 from functools import partial
-from typing import Dict, Optional
+from typing import Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn

@@ -25,6 +25,7 @@
 from timm.layers import create_conv2d, create_norm_layer, get_act_layer, get_norm_layer, ConvNormAct
 from timm.layers import DropPath, trunc_normal_, to_2tuple, to_ntuple, ndgrid
 from ._builder import build_model_with_cfg
+from ._features import feature_take_indices
 from ._manipulate import checkpoint_seq
 from ._registry import generate_default_cfgs, register_model
 

@@ -625,6 +626,73 @@ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
     def set_distilled_training(self, enable=True):
         self.distilled_training = enable
 
+    def forward_intermediates(
+            self,
+            x: torch.Tensor,
+            indices: Optional[Union[int, List[int]]] = None,
+            norm: bool = False,
+            stop_early: bool = False,
+            output_fmt: str = 'NCHW',
+            intermediates_only: bool = False,
+    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
+        """ Forward features that returns intermediates.
+
+        Args:
+            x: Input image tensor
+            indices: Take last n blocks if int, all if None, select matching indices if sequence
+            norm: Apply norm layer to compatible intermediates
+            stop_early: Stop iterating over blocks when last desired intermediate hit
+            output_fmt: Shape of intermediate feature outputs
+            intermediates_only: Only return intermediate features
+        Returns:
+
+        """
+        assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
+        intermediates = []
+        take_indices, max_index = feature_take_indices(len(self.stages), indices)
+
+        # forward pass
+        x = self.stem(x)
+
+        last_idx = len(self.stages) - 1
+        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
+            stages = self.stages
+        else:
+            stages = self.stages[:max_index + 1]
+
+        for feat_idx, stage in enumerate(stages):
+            x = stage(x)
+            if feat_idx in take_indices:
+                if feat_idx == last_idx:
+                    x_inter = self.norm(x) if norm else x
+                    intermediates.append(x_inter)
+                else:
+                    intermediates.append(x)
+
+        if intermediates_only:
+            return intermediates
+
+        if feat_idx == last_idx:
+            x = self.norm(x)
+
+        return x, intermediates
+
+    def prune_intermediate_layers(
+            self,
+            indices: Union[int, List[int]] = 1,
+            prune_norm: bool = False,
+            prune_head: bool = True,
+    ):
+        """ Prune layers not required for specified intermediates.
+        """
+        take_indices, max_index = feature_take_indices(len(self.stages), indices)
+        self.stages = self.stages[:max_index + 1]  # truncate blocks w/ stem as idx 0
+        if prune_norm:
+            self.norm = nn.Identity()
+        if prune_head:
+            self.reset_classifier(0, '')
+        return take_indices
+
     def forward_features(self, x):
         x = self.stem(x)
         x = self.stages(x)
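One small difference from the models above: EfficientFormer-V2 applies self.norm rather than a norm_pre, and only normalizes the last intermediate when norm=True is passed. A sketch contrasting the two calls (model name assumed from timm's registry):

import torch
import timm

model = timm.create_model('efficientformerv2_s0', pretrained=False).eval()
x = torch.randn(1, 3, 224, 224)

with torch.no_grad():
    raw = model.forward_intermediates(x, indices=1, intermediates_only=True)
    normed = model.forward_intermediates(x, indices=1, norm=True, intermediates_only=True)

# Same shape either way; normed[-1] has self.norm applied, raw[-1] does not.
print(raw[-1].shape, normed[-1].shape)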
