@@ -532,7 +532,11 @@ def _apply_learned_pos_embed(
532532 pos_embed_flat = self .pos_embed .reshape (1 , orig_h * orig_w , - 1 )
533533 else :
534534 # Resize if needed - directly using F.interpolate
535- _interp_size = to_2tuple (max (grid_size )) if self .pos_embed_ar_preserving else grid_size
535+ if self .pos_embed_ar_preserving :
536+ L = max (grid_size )
537+ _interp_size = L , L
538+ else :
539+ _interp_size = grid_size
536540 pos_embed_flat = F .interpolate (
537541 self .pos_embed .permute (0 , 3 , 1 , 2 ).float (), # B,C,H,W
538542 size = _interp_size ,
@@ -968,7 +972,7 @@ def __init__(
968972 cfg: Model configuration. If None, uses default NaFlexVitCfg.
969973 in_chans: Number of input image channels.
970974 num_classes: Number of classification classes.
971- img_size: Input image size for backwards compatibility.
975+ img_size: Input image size ( for backwards compatibility with classic vit) .
972976 **kwargs: Additional config parameters to override cfg values.
973977 """
974978 super ().__init__ ()
@@ -1523,9 +1527,9 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
15231527 return {
15241528 'url' : url ,
15251529 'num_classes' : 1000 ,
1526- 'input_size' : (3 , 256 , 256 ),
1530+ 'input_size' : (3 , 384 , 384 ),
15271531 'pool_size' : None ,
1528- 'crop_pct' : 0.95 ,
1532+ 'crop_pct' : 1.0 ,
15291533 'interpolation' : 'bicubic' ,
15301534 'mean' : IMAGENET_INCEPTION_MEAN ,
15311535 'std' : IMAGENET_INCEPTION_STD ,
@@ -1537,11 +1541,19 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
15371541
15381542
# Pretrained-weight registry. Tag convention: `.e<epochs>_s<seq>_in1k` marks
# weights trained on ImageNet-1k and hosted on the HF hub; `.untrained` marks
# architectures registered without released weights.
default_cfgs = generate_default_cfgs({
    'naflexvit_base_patch16_gap.e300_s576_in1k': _cfg(hf_hub_id='timm/'),
    'naflexvit_base_patch16_par_gap.e300_s576_in1k': _cfg(hf_hub_id='timm/'),
    'naflexvit_base_patch16_parfac_gap.e300_s576_in1k': _cfg(hf_hub_id='timm/'),
    'naflexvit_base_patch16_map.untrained': _cfg(),

    'naflexvit_base_patch16_siglip.untrained': _cfg(),
    'naflexvit_so400m_patch16_siglip.untrained': _cfg(),
})
15461558
15471559
@@ -1623,6 +1635,45 @@ def naflexvit_base_patch16_gap(pretrained: bool = False, **kwargs) -> NaFlexVit:
16231635 return model
16241636
16251637
@register_model
def naflexvit_base_patch16_par_gap(pretrained: bool = False, **kwargs) -> NaFlexVit:
    """ViT-Base (NaFlex) with aspect-ratio preserving pos embed and global average pooling.

    Args:
        pretrained: Load pretrained weights if True.
        **kwargs: Overrides applied on top of the base configuration.

    Returns:
        Instantiated NaFlexVit model.
    """
    model_cfg = NaFlexVitCfg(
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        init_values=1e-5,
        pos_embed_ar_preserving=True,
        global_pool='avg',
        reg_tokens=4,
        fc_norm=True,
    )
    return _create_naflexvit('naflexvit_base_patch16_par_gap', pretrained=pretrained, cfg=model_cfg, **kwargs)
1655+
1656+
@register_model
def naflexvit_base_patch16_parfac_gap(pretrained: bool = False, **kwargs) -> NaFlexVit:
    """ViT-Base (NaFlex) with aspect-ratio preserving, factorized pos embed and global average pooling.

    Args:
        pretrained: Load pretrained weights if True.
        **kwargs: Overrides applied on top of the base configuration.

    Returns:
        Instantiated NaFlexVit model.
    """
    # Same base recipe as the `par_gap` variant, plus a factorized pos embed.
    base_args = dict(
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        init_values=1e-5,
        pos_embed_ar_preserving=True,
        pos_embed='factorized',
        global_pool='avg',
        reg_tokens=4,
        fc_norm=True,
    )
    return _create_naflexvit(
        'naflexvit_base_patch16_parfac_gap',
        pretrained=pretrained,
        cfg=NaFlexVitCfg(**base_args),
        **kwargs,
    )
1675+
1676+
16261677@register_model
16271678def naflexvit_base_patch16_map (pretrained : bool = False , ** kwargs ) -> NaFlexVit :
16281679 """ViT-Base with NaFlex functionality and MAP attention pooling.
0 commit comments