@@ -84,24 +84,34 @@ def _cfg(url='', **kwargs):
8484 url = '' , input_size = (3 , 380 , 380 ), pool_size = (12 , 12 ), crop_pct = 0.922 ),
8585 'efficientnet_b5' : _cfg (
8686 url = '' , input_size = (3 , 456 , 456 ), pool_size = (15 , 15 ), crop_pct = 0.934 ),
87+ 'efficientnet_b6' : _cfg (
88+ url = '' , input_size = (3 , 528 , 528 ), pool_size = (17 , 17 ), crop_pct = 0.942 ),
89+ 'efficientnet_b7' : _cfg (
90+ url = '' , input_size = (3 , 600 , 600 ), pool_size = (19 , 19 ), crop_pct = 0.949 ),
8791 'tf_efficientnet_b0' : _cfg (
88- url = 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0-0af12548 .pth' ,
92+ url = 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_aa-827b6e33 .pth' ,
8993 input_size = (3 , 224 , 224 )),
9094 'tf_efficientnet_b1' : _cfg (
91- url = 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1-5c1377c4 .pth' ,
95+ url = 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_aa-ea7a6ee0 .pth' ,
9296 input_size = (3 , 240 , 240 ), pool_size = (8 , 8 ), crop_pct = 0.882 ),
9397 'tf_efficientnet_b2' : _cfg (
94- url = 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2-e393ef04 .pth' ,
98+ url = 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_aa-60c94f97 .pth' ,
9599 input_size = (3 , 260 , 260 ), pool_size = (9 , 9 ), crop_pct = 0.890 ),
96100 'tf_efficientnet_b3' : _cfg (
97- url = 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3-e3bd6955 .pth' ,
101+ url = 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_aa-84b4657e .pth' ,
98102 input_size = (3 , 300 , 300 ), pool_size = (10 , 10 ), crop_pct = 0.904 ),
99103 'tf_efficientnet_b4' : _cfg (
100- url = 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4-74ee3bed .pth' ,
104+ url = 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_aa-818f208c .pth' ,
101105 input_size = (3 , 380 , 380 ), pool_size = (12 , 12 ), crop_pct = 0.922 ),
102106 'tf_efficientnet_b5' : _cfg (
103- url = 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5-c6949ce9 .pth' ,
107+ url = 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_aa-99018a74 .pth' ,
104108 input_size = (3 , 456 , 456 ), pool_size = (15 , 15 ), crop_pct = 0.934 ),
109+ 'tf_efficientnet_b6' : _cfg (
110+ url = 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_aa-80ba17e4.pth' ,
111+ input_size = (3 , 528 , 528 ), pool_size = (17 , 17 ), crop_pct = 0.942 ),
112+ 'tf_efficientnet_b7' : _cfg (
113+ url = 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_aa-076e3472.pth' ,
114+ input_size = (3 , 600 , 600 ), pool_size = (19 , 19 ), crop_pct = 0.949 ),
105115 'mixnet_s' : _cfg (url = '' ),
106116 'mixnet_m' : _cfg (
107117 url = 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_m-4647fc68.pth' ),
@@ -763,8 +773,6 @@ def _gen_mnasnet_a1(channel_multiplier, num_classes=1000, **kwargs):
763773 num_classes = num_classes ,
764774 stem_size = 32 ,
765775 channel_multiplier = channel_multiplier ,
766- channel_divisor = 8 ,
767- channel_min = None ,
768776 bn_args = _resolve_bn_args (kwargs ),
769777 ** kwargs
770778 )
@@ -801,8 +809,6 @@ def _gen_mnasnet_b1(channel_multiplier, num_classes=1000, **kwargs):
801809 num_classes = num_classes ,
802810 stem_size = 32 ,
803811 channel_multiplier = channel_multiplier ,
804- channel_divisor = 8 ,
805- channel_min = None ,
806812 bn_args = _resolve_bn_args (kwargs ),
807813 ** kwargs
808814 )
@@ -832,8 +838,6 @@ def _gen_mnasnet_small(channel_multiplier, num_classes=1000, **kwargs):
832838 num_classes = num_classes ,
833839 stem_size = 8 ,
834840 channel_multiplier = channel_multiplier ,
835- channel_divisor = 8 ,
836- channel_min = None ,
837841 bn_args = _resolve_bn_args (kwargs ),
838842 ** kwargs
839843 )
@@ -858,8 +862,6 @@ def _gen_mobilenet_v1(channel_multiplier, num_classes=1000, **kwargs):
858862 stem_size = 32 ,
859863 num_features = 1024 ,
860864 channel_multiplier = channel_multiplier ,
861- channel_divisor = 8 ,
862- channel_min = None ,
863865 bn_args = _resolve_bn_args (kwargs ),
864866 act_fn = F .relu6 ,
865867 head_conv = 'none' ,
@@ -887,8 +889,6 @@ def _gen_mobilenet_v2(channel_multiplier, num_classes=1000, **kwargs):
887889 num_classes = num_classes ,
888890 stem_size = 32 ,
889891 channel_multiplier = channel_multiplier ,
890- channel_divisor = 8 ,
891- channel_min = None ,
892892 bn_args = _resolve_bn_args (kwargs ),
893893 act_fn = F .relu6 ,
894894 ** kwargs
@@ -926,8 +926,6 @@ def _gen_mobilenet_v3(channel_multiplier, num_classes=1000, **kwargs):
926926 num_classes = num_classes ,
927927 stem_size = 16 ,
928928 channel_multiplier = channel_multiplier ,
929- channel_divisor = 8 ,
930- channel_min = None ,
931929 bn_args = _resolve_bn_args (kwargs ),
932930 act_fn = hard_swish ,
933931 se_gate_fn = hard_sigmoid ,
@@ -961,8 +959,6 @@ def _gen_chamnet_v1(channel_multiplier, num_classes=1000, **kwargs):
961959 stem_size = 32 ,
962960 num_features = 1280 , # no idea what this is? try mobile/mnasnet default?
963961 channel_multiplier = channel_multiplier ,
964- channel_divisor = 8 ,
965- channel_min = None ,
966962 bn_args = _resolve_bn_args (kwargs ),
967963 ** kwargs
968964 )
@@ -992,8 +988,6 @@ def _gen_chamnet_v2(channel_multiplier, num_classes=1000, **kwargs):
992988 stem_size = 32 ,
993989 num_features = 1280 , # no idea what this is? try mobile/mnasnet default?
994990 channel_multiplier = channel_multiplier ,
995- channel_divisor = 8 ,
996- channel_min = None ,
997991 bn_args = _resolve_bn_args (kwargs ),
998992 ** kwargs
999993 )
@@ -1024,8 +1018,6 @@ def _gen_fbnetc(channel_multiplier, num_classes=1000, **kwargs):
10241018 stem_size = 16 ,
10251019 num_features = 1984 , # paper suggests this, but is not 100% clear
10261020 channel_multiplier = channel_multiplier ,
1027- channel_divisor = 8 ,
1028- channel_min = None ,
10291021 bn_args = _resolve_bn_args (kwargs ),
10301022 ** kwargs
10311023 )
@@ -1061,8 +1053,6 @@ def _gen_spnasnet(channel_multiplier, num_classes=1000, **kwargs):
10611053 num_classes = num_classes ,
10621054 stem_size = 32 ,
10631055 channel_multiplier = channel_multiplier ,
1064- channel_divisor = 8 ,
1065- channel_min = None ,
10661056 bn_args = _resolve_bn_args (kwargs ),
10671057 ** kwargs
10681058 )
@@ -1107,8 +1097,6 @@ def _gen_efficientnet(channel_multiplier=1.0, depth_multiplier=1.0, num_classes=
11071097 num_classes = num_classes ,
11081098 stem_size = 32 ,
11091099 channel_multiplier = channel_multiplier ,
1110- channel_divisor = 8 ,
1111- channel_min = None ,
11121100 num_features = num_features ,
11131101 bn_args = _resolve_bn_args (kwargs ),
11141102 act_fn = swish ,
@@ -1144,8 +1132,6 @@ def _gen_mixnet_s(channel_multiplier=1.0, num_classes=1000, **kwargs):
11441132 stem_size = 16 ,
11451133 num_features = 1536 ,
11461134 channel_multiplier = channel_multiplier ,
1147- channel_divisor = 8 ,
1148- channel_min = None ,
11491135 bn_args = _resolve_bn_args (kwargs ),
11501136 act_fn = F .relu ,
11511137 ** kwargs
@@ -1180,8 +1166,6 @@ def _gen_mixnet_m(channel_multiplier=1.0, num_classes=1000, **kwargs):
11801166 stem_size = 24 ,
11811167 num_features = 1536 ,
11821168 channel_multiplier = channel_multiplier ,
1183- channel_divisor = 8 ,
1184- channel_min = None ,
11851169 bn_args = _resolve_bn_args (kwargs ),
11861170 act_fn = F .relu ,
11871171 ** kwargs
@@ -1495,6 +1479,37 @@ def efficientnet_b5(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
14951479 return model
14961480
14971481
1482+
1483+ @register_model
1484+ def efficientnet_b6 (pretrained = False , num_classes = 1000 , in_chans = 3 , ** kwargs ):
1485+ """ EfficientNet-B6 """
1486+ # NOTE for train, drop_rate should be 0.5
1487+ #kwargs['drop_connect_rate'] = 0.2 # set when training, TODO add as cmd arg
1488+ default_cfg = default_cfgs ['efficientnet_b6' ]
1489+ model = _gen_efficientnet (
1490+ channel_multiplier = 1.8 , depth_multiplier = 2.6 ,
1491+ num_classes = num_classes , in_chans = in_chans , ** kwargs )
1492+ model .default_cfg = default_cfg
1493+ if pretrained :
1494+ load_pretrained (model , default_cfg , num_classes , in_chans )
1495+ return model
1496+
1497+
1498+ @register_model
1499+ def efficientnet_b7 (pretrained = False , num_classes = 1000 , in_chans = 3 , ** kwargs ):
1500+ """ EfficientNet-B7 """
1501+ # NOTE for train, drop_rate should be 0.5
1502+ #kwargs['drop_connect_rate'] = 0.2 # set when training, TODO add as cmd arg
1503+ default_cfg = default_cfgs ['efficientnet_b7' ]
1504+ model = _gen_efficientnet (
1505+ channel_multiplier = 2.0 , depth_multiplier = 3.1 ,
1506+ num_classes = num_classes , in_chans = in_chans , ** kwargs )
1507+ model .default_cfg = default_cfg
1508+ if pretrained :
1509+ load_pretrained (model , default_cfg , num_classes , in_chans )
1510+ return model
1511+
1512+
14981513@register_model
14991514def tf_efficientnet_b0 (pretrained = False , num_classes = 1000 , in_chans = 3 , ** kwargs ):
15001515 """ EfficientNet-B0. Tensorflow compatible variant """
@@ -1585,6 +1600,38 @@ def tf_efficientnet_b5(pretrained=False, num_classes=1000, in_chans=3, **kwargs)
15851600 return model
15861601
15871602
1603+ @register_model
1604+ def tf_efficientnet_b6 (pretrained = False , num_classes = 1000 , in_chans = 3 , ** kwargs ):
1605+ """ EfficientNet-B6. Tensorflow compatible variant """
1606+ # NOTE for train, drop_rate should be 0.5
1607+ default_cfg = default_cfgs ['tf_efficientnet_b6' ]
1608+ kwargs ['bn_eps' ] = _BN_EPS_TF_DEFAULT
1609+ kwargs ['pad_type' ] = 'same'
1610+ model = _gen_efficientnet (
1611+ channel_multiplier = 1.8 , depth_multiplier = 2.6 ,
1612+ num_classes = num_classes , in_chans = in_chans , ** kwargs )
1613+ model .default_cfg = default_cfg
1614+ if pretrained :
1615+ load_pretrained (model , default_cfg , num_classes , in_chans )
1616+ return model
1617+
1618+
1619+ @register_model
1620+ def tf_efficientnet_b7 (pretrained = False , num_classes = 1000 , in_chans = 3 , ** kwargs ):
1621+ """ EfficientNet-B7. Tensorflow compatible variant """
1622+ # NOTE for train, drop_rate should be 0.5
1623+ default_cfg = default_cfgs ['tf_efficientnet_b7' ]
1624+ kwargs ['bn_eps' ] = _BN_EPS_TF_DEFAULT
1625+ kwargs ['pad_type' ] = 'same'
1626+ model = _gen_efficientnet (
1627+ channel_multiplier = 2.0 , depth_multiplier = 3.1 ,
1628+ num_classes = num_classes , in_chans = in_chans , ** kwargs )
1629+ model .default_cfg = default_cfg
1630+ if pretrained :
1631+ load_pretrained (model , default_cfg , num_classes , in_chans )
1632+ return model
1633+
1634+
15881635@register_model
15891636def mixnet_s (pretrained = False , num_classes = 1000 , in_chans = 3 , ** kwargs ):
15901637 """Creates a MixNet Small model.
0 commit comments