@@ -10,7 +10,7 @@
 
 from collections import OrderedDict
 from functools import partial
-from typing import List, Final, Optional, Tuple
+from typing import List, Final, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -379,17 +379,17 @@ def __init__(
         if self.cls_token is not None:
             trunc_normal_(self.cls_token, std=.02)
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward(self, x: torch.Tensor) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         x = self.conv_embed(x)
         x = self.embed_drop(x)
 
         cls_token = self.embed_drop(
             self.cls_token.expand(x.shape[0], -1, -1)
         ) if self.cls_token is not None else None
-        for block in self.blocks:  # technically possible to exploit nn.Sequential's untyped intermediate results if each block takes in a tensor
+        for block in self.blocks:  # TODO technically possible to exploit nn.Sequential's untyped intermediate results if each block takes in a tuple
             x, cls_token = block(x, cls_token)
 
-        return x, cls_token
+        return (x, cls_token) if self.cls_token is not None else x
 
 class CvT(nn.Module):
     def __init__(
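A note on the TODO in the hunk above: nn.Sequential.forward simply feeds each module's output straight into the next call with no type checking, so the per-block loop could indeed be collapsed into an nn.Sequential whose blocks consume and produce an (x, cls_token) tuple. A minimal standalone sketch of that mechanism (TupleBlock is a toy stand-in, not the block class in this file):

import torch
import torch.nn as nn
from typing import Optional, Tuple

class TupleBlock(nn.Module):
    # toy block: accepts and returns a (x, cls_token) tuple, so it can be
    # chained inside nn.Sequential, whose forward is untyped
    def forward(self, xs: Tuple[torch.Tensor, Optional[torch.Tensor]]):
        x, cls_token = xs
        return x * 2, cls_token  # stand-in for the real attention/MLP work

blocks = nn.Sequential(TupleBlock(), TupleBlock())
x, cls_token = blocks((torch.ones(2, 8), None))  # tuple threads through both blocks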
@@ -429,8 +429,8 @@ def __init__(
         assert num_stages == len(embed_padding) == len(num_heads) == len(use_cls_token)
         self.num_classes = num_classes
         self.num_features = dims[-1]
+        self.feature_info = []
 
-        # FIXME only on last stage, no need for tuple
         self.use_cls_token = use_cls_token[-1]
 
         dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
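For reference, the dpr context line above builds a linearly increasing stochastic depth schedule over all blocks and then splits it into per-stage lists. A worked sketch of what it produces, assuming example values (the depths and rate below are assumptions, not taken from this diff):

import torch

depths = [1, 2, 10]       # assumed example; the real values come from the model config
drop_path_rate = 0.1
dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
# dpr[0] has 1 rate, dpr[1] has 2, dpr[2] has 10, rising linearly from 0.0 to 0.1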
@@ -473,28 +473,27 @@ def __init__(
             )
             in_chs = dim
             stages.append(stage)
-        self.stages = nn.ModuleList(stages)
+            self.feature_info += [dict(num_chs=dim, reduction=2, module=f'stages.{stage_idx}')]
+        self.stages = nn.Sequential(*stages)
 
         self.norm = norm_layer(dims[-1])
         self.head = nn.Linear(dims[-1], num_classes) if num_classes > 0 else nn.Identity()
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
 
         for stage in self.stages:
-            x, cls_token = stage(x)
+            x = stage(x)
 
 
         if self.use_cls_token:
-            return self.head(self.norm(cls_token.flatten(1)))
+            return self.head(self.norm(x[1].flatten(1)))
         else:
             return self.head(self.norm(x.mean(dim=(2,3))))
 
 
 
 def checkpoint_filter_fn(state_dict, model):
     """ Remap MSFT checkpoints -> timm """
-    if 'head.fc.weight' in state_dict:
-        return state_dict # non-MSFT checkpoint
 
     if 'state_dict' in state_dict:
         state_dict = state_dict['state_dict']
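Switching stages from nn.ModuleList to nn.Sequential keeps the explicit loop in forward working (both containers are iterable) and additionally makes the container itself callable. A small self-contained sketch of that loop/call equivalence, using toy Linear stages rather than the real CvT stages:

import torch
import torch.nn as nn

stages = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))  # toy stand-ins
x = torch.randn(2, 4)

out_loop = x
for stage in stages:      # the explicit loop used in CvT.forward
    out_loop = stage(out_loop)

out_call = stages(x)      # nn.Sequential does the same chaining internally
assert torch.allclose(out_loop, out_call)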
@@ -524,14 +523,13 @@ def _create_cvt(variant, pretrained=False, **kwargs):
 
     return model
 
-# TODO update first_conv
 def _cfg(url='', **kwargs):
     return {
         'url': url,
         'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (14, 14),
         'crop_pct': 0.95, 'interpolation': 'bicubic',
         'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
-        'first_conv': 'stem.conv', 'classifier': 'head',
+        'first_conv': 'stages.0.conv_embed.conv', 'classifier': 'head',
         **kwargs
     }
 
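The first_conv path in _cfg now points at the conv layer inside the first stage's patch embed; timm consults this key when adapting pretrained weights, e.g. for a different number of input channels. A hedged sanity-check sketch (check_first_conv is my helper name, and model is assumed to be a constructed CvT instance):

import torch.nn as nn

def check_first_conv(model: nn.Module, path: str = 'stages.0.conv_embed.conv') -> None:
    # get_submodule raises AttributeError if the dotted path goes stale,
    # so this catches cfg/model drift early
    conv = model.get_submodule(path)
    assert isinstance(conv, nn.Conv2d), f'expected Conv2d at {path}, got {type(conv)}'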