@@ -7,7 +7,7 @@
 Hacked together by / Copyright 2019, Ross Wightman
 """
 from functools import partial
-from typing import Callable, List, Optional, Tuple
+from typing import Callable, List, Optional, Tuple, Union

 import torch
 import torch.nn as nn
@@ -20,7 +20,7 @@
 from ._efficientnet_blocks import SqueezeExcite
 from ._efficientnet_builder import BlockArgs, EfficientNetBuilder, decode_arch_def, efficientnet_init_weights, \
     round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
-from ._features import FeatureInfo, FeatureHooks
+from ._features import FeatureInfo, FeatureHooks, feature_take_indices
 from ._manipulate import checkpoint_seq
 from ._registry import generate_default_cfgs, register_model, register_model_deprecations
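For reference, `feature_take_indices` (newly imported from `._features`) normalizes the flexible `indices` argument used throughout the new methods. A minimal sketch of the assumed behavior, not the actual timm implementation:

```python
from typing import List, Optional, Tuple, Union

def feature_take_indices_sketch(
        num_features: int,
        indices: Optional[Union[int, List[int]]] = None,
) -> Tuple[List[int], int]:
    """None -> take all features; int n -> take the last n; a sequence -> take the
    listed positions (negatives counted from the end). Also returns the max index,
    which callers use to stop the forward pass early."""
    if indices is None:
        indices = list(range(num_features))
    elif isinstance(indices, int):
        assert 0 < indices <= num_features
        indices = list(range(num_features - indices, num_features))
    else:
        indices = [num_features + i if i < 0 else i for i in indices]
    return list(indices), max(indices)
```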
@@ -109,6 +109,7 @@ def __init__(
         )
         self.blocks = nn.Sequential(*builder(stem_size, block_args))
         self.feature_info = builder.features
+        self.stage_ends = [f['stage'] for f in self.feature_info]
         head_chs = builder.in_chs

         # Head + Pooling
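The new `stage_ends` attribute records, for each `feature_info` entry, the `'stage'` counter at which that feature is produced (the stem counts as stage 0, the n-th block as stage n). `forward_intermediates` below uses it to translate feature positions into stage counters; a hypothetical illustration with made-up stage values:

```python
# Hypothetical: feature_info entries reporting stages [1, 2, 3, 5, 7].
stage_ends = [1, 2, 3, 5, 7]

# Taking the last two features (positions 3 and 4) maps to stage counters 5 and 7,
# so the forward pass captures x after the 5th and 7th stages.
take_positions = [3, 4]
take_stages = [stage_ends[i] for i in take_positions]
assert take_stages == [5, 7]
```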
@@ -150,6 +151,84 @@ def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
         self.flatten = nn.Flatten(1) if global_pool else nn.Identity()  # don't flatten if pooling disabled
         self.classifier = Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

+    def forward_intermediates(
+            self,
+            x: torch.Tensor,
+            *,
+            indices: Optional[Union[int, List[int], Tuple[int, ...]]] = None,
+            norm: bool = False,
+            stop_early: bool = False,
+            output_fmt: str = 'NCHW',
+            intermediates_only: bool = False,
+            extra_blocks: bool = False,
+    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
+        """ Forward features that returns intermediates.
+
+        Args:
+            x: Input image tensor
+            indices: Take last n blocks if int, all if None, select matching indices if sequence
+            norm: Apply norm layer to compatible intermediates
+            stop_early: Stop iterating over blocks when last desired intermediate hit
+            output_fmt: Shape of intermediate feature outputs
+            intermediates_only: Only return intermediate features
+            extra_blocks: Include outputs of all blocks and head conv in output, does not align with feature_info
+        Returns:
+            List of intermediates if `intermediates_only` is True, otherwise a tuple
+            of (final_features, intermediates).
+        """
+        assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
+        if stop_early:
+            assert intermediates_only, 'Must use intermediates_only for early stopping.'
+        intermediates = []
+        if extra_blocks:
+            take_indices, max_index = feature_take_indices(len(self.blocks) + 1, indices)
+        else:
+            take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
+            take_indices = [self.stage_ends[i] for i in take_indices]
+            max_index = self.stage_ends[max_index]
+        # forward pass
+        feat_idx = 0  # stem is index 0
+        x = self.conv_stem(x)
+        x = self.bn1(x)
+        if feat_idx in take_indices:
+            intermediates.append(x)
+        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
+            blocks = self.blocks
+        else:
+            blocks = self.blocks[:max_index]
+        for blk in blocks:
+            feat_idx += 1
+            x = blk(x)
+            if feat_idx in take_indices:
+                intermediates.append(x)
+
+        if intermediates_only:
+            return intermediates
+
+        return x, intermediates
+
+    def prune_intermediate_layers(
+            self,
+            indices: Union[int, List[int], Tuple[int, ...]] = 1,
+            prune_norm: bool = False,
+            prune_head: bool = True,
+            extra_blocks: bool = False,
+    ):
+        """ Prune layers not required for specified intermediates.
+        """
+        if extra_blocks:
+            take_indices, max_index = feature_take_indices(len(self.blocks) + 1, indices)
+        else:
+            take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
+            max_index = self.stage_ends[max_index]
+        if max_index < len(self.blocks):
+            self.conv_head = nn.Identity()  # conv_head unusable once trailing blocks are pruned
+        self.blocks = self.blocks[:max_index]  # truncate blocks w/ stem as idx 0
+        if prune_head:
+            self.conv_head = nn.Identity()
+            self.reset_classifier(0, '')
+        return take_indices
+
     def forward_features(self, x: torch.Tensor) -> torch.Tensor:
         x = self.conv_stem(x)
         x = self.bn1(x)
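Taken together, a usage sketch for the two new methods (`mobilenetv3_large_100` stands in for any variant built on this class):

```python
import torch
import timm

model = timm.create_model('mobilenetv3_large_100', pretrained=False)
x = torch.randn(1, 3, 224, 224)

# Final features plus every intermediate aligned with feature_info.
final, intermediates = model.forward_intermediates(x)

# Only the last two feature stages, skipping blocks past the last one needed.
feats = model.forward_intermediates(x, indices=2, stop_early=True, intermediates_only=True)

# Remove blocks and head layers not required for those two stages.
kept = model.prune_intermediate_layers(indices=2, prune_head=True)
```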
@@ -288,7 +367,7 @@ def _create_mnv3(variant: str, pretrained: bool = False, **kwargs) -> MobileNetV3:
     model_cls = MobileNetV3
     kwargs_filter = None
     if kwargs.pop('features_only', False):
-        if 'feature_cfg' in kwargs:
+        if 'feature_cfg' in kwargs or 'feature_cls' in kwargs:
             features_mode = 'cfg'
         else:
             kwargs_filter = ('num_classes', 'num_features', 'head_conv', 'head_bias', 'global_pool')
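With the extra `feature_cls` check, features-only creation routes to the config-driven feature wrapper when either kwarg is present. A hedged example; `feature_cls='getter'` assumes the value accepted by timm's feature-wrapper selection:

```python
import timm

# Pre-existing path: filtered kwargs, hook/wrapper based extraction.
m1 = timm.create_model('mobilenetv3_large_100', features_only=True)

# New path: feature_cls alone now selects the 'cfg' features mode, allowing a
# forward_intermediates-backed getter wrapper to be requested at creation time.
m2 = timm.create_model('mobilenetv3_large_100', features_only=True, feature_cls='getter')
```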