|
12 | 12 | # All rights reserved. |
13 | 13 | # This source code is licensed under the MIT license |
14 | 14 | from functools import partial |
15 | | -from typing import Optional, Tuple |
| 15 | +from typing import List, Optional, Tuple, Union |
16 | 16 |
|
17 | 17 | import torch |
18 | 18 | import torch.nn as nn |
|
23 | 23 | from timm.layers import DropPath, to_2tuple, trunc_normal_, Mlp, LayerNorm2d, get_norm_layer, use_fused_attn |
24 | 24 | from timm.layers import NormMlpClassifierHead, ClassifierHead |
25 | 25 | from ._builder import build_model_with_cfg |
| 26 | +from ._features import feature_take_indices |
26 | 27 | from ._features_fx import register_notrace_function |
27 | 28 | from ._manipulate import checkpoint_seq |
28 | 29 | from ._registry import generate_default_cfgs, register_model |
@@ -636,6 +637,72 @@ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): |
636 | 637 | self.num_classes = num_classes |
637 | 638 | self.head.reset(num_classes, global_pool) |
638 | 639 |
|
| 640 | + def forward_intermediates( |
| 641 | + self, |
| 642 | + x: torch.Tensor, |
| 643 | + indices: Optional[Union[int, List[int]]] = None, |
| 644 | + norm: bool = False, |
| 645 | + stop_early: bool = False, |
| 646 | + output_fmt: str = 'NCHW', |
| 647 | + intermediates_only: bool = False, |
| 648 | + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: |
| 649 | + """ Forward features that returns intermediates. |
| 650 | +
|
| 651 | + Args: |
| 652 | + x: Input image tensor |
| 653 | + indices: Take last n blocks if int, all if None, select matching indices if sequence |
| 654 | + norm: Apply norm layer to compatible intermediates |
| 655 | + stop_early: Stop iterating over blocks when last desired intermediate hit |
| 656 | + output_fmt: Shape of intermediate feature outputs |
| 657 | + intermediates_only: Only return intermediate features |
| 658 | + Returns: |
| 659 | +
|
| 660 | + """ |
| 661 | + assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' |
| 662 | + intermediates = [] |
| 663 | + take_indices, max_index = feature_take_indices(len(self.stages), indices) |
| 664 | + |
| 665 | + # forward pass |
| 666 | + x = self.stem(x) |
| 667 | + last_idx = len(self.stages) - 1 |
| 668 | + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript |
| 669 | + stages = self.stages |
| 670 | + else: |
| 671 | + stages = self.stages[:max_index + 1] |
| 672 | + |
| 673 | + for feat_idx, stage in enumerate(stages): |
| 674 | + x = stage(x) |
| 675 | + if feat_idx in take_indices: |
| 676 | + if norm and feat_idx == last_idx: |
| 677 | + x_inter = self.norm_pre(x) # applying final norm to last intermediate |
| 678 | + else: |
| 679 | + x_inter = x |
| 680 | + intermediates.append(x_inter) |
| 681 | + |
| 682 | + if intermediates_only: |
| 683 | + return intermediates |
| 684 | + |
| 685 | + if feat_idx == last_idx: |
| 686 | + x = self.norm_pre(x) |
| 687 | + |
| 688 | + return x, intermediates |
| 689 | + |
| 690 | + def prune_intermediate_layers( |
| 691 | + self, |
| 692 | + indices: Union[int, List[int]] = 1, |
| 693 | + prune_norm: bool = False, |
| 694 | + prune_head: bool = True, |
| 695 | + ): |
| 696 | + """ Prune layers not required for specified intermediates. |
| 697 | + """ |
| 698 | + take_indices, max_index = feature_take_indices(len(self.stages), indices) |
| 699 | + self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0 |
| 700 | + if prune_norm: |
| 701 | + self.norm_pre = nn.Identity() |
| 702 | + if prune_head: |
| 703 | + self.reset_classifier(0, '') |
| 704 | + return take_indices |
| 705 | + |
639 | 706 | def forward_features(self, x): |
640 | 707 | x = self.stem(x) |
641 | 708 | if self.grad_checkpointing and not torch.jit.is_scripting(): |
|
0 commit comments