Commit 8fac5b5

Merge pull request #293 from zezeze97/master: Add SMILE

2 parents: b3f8c06 + 51ab3ed

File tree

11 files changed: +1897 −0 lines changed


CV/SMILE/README.md

Lines changed: 98 additions & 0 deletions
@@ -0,0 +1,98 @@
# SMILE: Self-Distilled MIxup for Efficient Transfer LEarning

## Introduction

This is the [PaddlePaddle](https://www.paddlepaddle.org.cn/) implementation of SMILE (spotlight at [INTERPOLATE@NeurIPS 2022](https://sites.google.com/view/interpolation-workshop?pli=1)) for image classification.

In this work, we propose SMILE: Self-Distilled MIxup for Efficient Transfer LEarning. With mixed images as inputs, SMILE regularizes the outputs of CNN feature extractors to learn from the mixed feature vectors of the inputs (sample-to-feature mixup), in addition to learning from the mixed labels. Specifically, SMILE incorporates a mean teacher, inherited from the pre-trained model, to provide the feature vectors of the input samples in a self-distilling fashion, and mixes up the feature vectors accordingly via a novel triplet regularizer. The triplet regularizer balances the mixup effects in both the feature and label spaces while bounding the linearity in-between samples for the pre-training tasks.
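The sketch below is one minimal reading of the objective described above, written in PaddlePaddle style. It is not the repo's actual implementation: the real triplet regularizer in `finetune.py` may be shaped differently, and every name here (`smile_objective`, `feat_student`, `feat_teacher_a`, `beta`, ...) is hypothetical.

```
import paddle.nn.functional as F

def smile_objective(logits_mixed, feat_student, feat_teacher_a, feat_teacher_b,
                    label_a, label_b, lam, beta=1.0):
    # Label-space mixup: the usual convex combination of cross-entropy terms
    # for the mixed image lam * x_a + (1 - lam) * x_b.
    ce = lam * F.cross_entropy(logits_mixed, label_a) + \
         (1 - lam) * F.cross_entropy(logits_mixed, label_b)

    # Sample-to-feature mixup: the student's feature for the mixed image is
    # pulled toward the mixture of the mean teacher's features for the two
    # clean inputs (the teacher is inherited from the pre-trained model).
    feat_target = lam * feat_teacher_a + (1 - lam) * feat_teacher_b
    distill = F.mse_loss(feat_student, feat_target.detach())

    return ce + beta * distill
```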
## Requirements

The code has been tested in the following environment:

* python >= 3.7
* numpy >= 1.21
* paddlepaddle >= 2.2 (with a matching CUDA and cuDNN build)
* visualdl
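A typical setup with pip might look as follows (the exact paddlepaddle wheel depends on your CUDA version; these are the standard PyPI package names, not versions pinned by this repo):

```
pip install "numpy>=1.21" visualdl
pip install "paddlepaddle-gpu>=2.2"  # or "paddlepaddle" for a CPU-only build
```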
## Model Training

### step1. Download dataset files

We conduct experiments on three popular object recognition datasets: CUB-200-2011, Stanford Cars and FGVC-Aircraft. You can download them from the official links below.

- [CUB-200-2011](http://www.vision.caltech.edu/datasets/cub_200_2011/)
- [Stanford Cars](https://ai.stanford.edu/~jkrause/cars/car_dataset.html)
- [FGVC-Aircraft](https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/)

Please organize your dataset in the following format:

```
dataset
├── train
│   ├── class_001
│   │   ├── 1.jpg
│   │   ├── 2.jpg
│   │   └── ...
│   ├── class_002
│   │   ├── 1.jpg
│   │   ├── 2.jpg
│   │   └── ...
│   └── ...
└── test
    ├── class_001
    │   ├── 1.jpg
    │   ├── 2.jpg
    │   └── ...
    ├── class_002
    │   ├── 1.jpg
    │   ├── 2.jpg
    │   └── ...
    └── ...
```
### step2. Finetune

You can use the following command to fine-tune on the target dataset with the SMILE algorithm. Log files and checkpoints are written to `./output` during training; only the model with the highest accuracy on the validation set is saved during fine-tuning.

```
python finetune.py --name {name of your experiment} --train_dir {path of train dir} --eval_dir {path of eval dir} --model_arch resnet50 --gpu {gpu id} --regularizer smile
```
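For example, a concrete invocation could look like the following; the experiment name and data paths are placeholders for your own setup, assuming the directory layout from step 1:

```
python finetune.py --name cub_smile_r50 --train_dir ./dataset/train --eval_dir ./dataset/test --model_arch resnet50 --gpu 0 --regularizer smile
```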
### step3. Test

You can also load the fine-tuned checkpoints with the following command and evaluate them on the test set.

```
python test.py --test_dir {path of test dir} --model_arch resnet50 --gpu {gpu id} --ckpts {path of finetuning ckpts}
```
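For instance, pointing the script at a checkpoint saved by step 2 (paths are placeholders; the exact checkpoint filename under `./output` depends on how `finetune.py` names it):

```
python test.py --test_dir ./dataset/test --model_arch resnet50 --gpu 0 --ckpts ./output/cub_smile_r50/best_model
```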
## Results

Top-1 accuracy (%):

| Dataset \ Method | L2 | SMILE |
|---|---|---|
| CUB-200-2011 | 80.79 | 82.38 |
| Stanford-Cars | 90.72 | 91.74 |
| FGVC-Aircraft | 86.93 | 89.00 |
## Citation

If you use any source code included in this project in your work, please cite the following paper:

```
@article{Li2021SMILESM,
  title={SMILE: Self-Distilled MIxup for Efficient Transfer LEarning},
  author={Xingjian Li and Haoyi Xiong and Chengzhong Xu and Dejing Dou},
  journal={ArXiv},
  year={2021},
  volume={abs/2103.13941}
}
```
## Copyright and License

Copyright 2019 Baidu.com, Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

CV/SMILE/backbones/__init__.py

Lines changed: 54 additions & 0 deletions
@@ -0,0 +1,54 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .resnet import ResNet  # noqa: F401
from .resnet import resnet18  # noqa: F401
from .resnet import resnet34  # noqa: F401
from .resnet import resnet50  # noqa: F401
from .resnet import resnet101  # noqa: F401
from .resnet import resnet152  # noqa: F401
from .mobilenetv2 import MobileNetV2  # noqa: F401
from .mobilenetv2 import mobilenet_v2  # noqa: F401
from .vit import VisionTransformer  # noqa: F401
from .vit import build_vit  # noqa: F401

# Export exactly the names imported above; listing backbones whose imports
# are absent (e.g. VGG, MobileNetV1, LeNet) would break
# `from backbones import *`.
__all__ = [  # noqa
    'ResNet',
    'resnet18',
    'resnet34',
    'resnet50',
    'resnet101',
    'resnet152',
    'MobileNetV2',
    'mobilenet_v2',
    'VisionTransformer',
    'build_vit',
]
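A minimal usage sketch for the package as exported above. It assumes the bundled `resnet.py` (not shown in this diff) mirrors the constructor signature of `paddle.vision.models.resnet50`; treat the `pretrained` flag as an assumption.

```
from backbones import resnet50

# Build the CNN feature extractor used for fine-tuning; the pretrained flag
# is assumed to behave as in paddle.vision.models.resnet50.
model = resnet50(pretrained=True)
```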

CV/SMILE/backbones/mobilenetv2.py

Lines changed: 228 additions & 0 deletions
@@ -0,0 +1,228 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn

from paddle.utils.download import get_weights_path_from_url

__all__ = []

model_urls = {
    'mobilenetv2_1.0':
    ('https://paddle-hapi.bj.bcebos.com/models/mobilenet_v2_x1.0.pdparams',
     '0340af0a901346c8d46f4529882fb63d')
}
def _make_divisible(v, divisor, min_value=None):
    # Round v to the nearest multiple of divisor, staying at or above
    # min_value and never dropping more than 10% below the original value,
    # e.g. _make_divisible(32 * 0.5, 8) == 16.
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)

    if new_v < 0.9 * v:
        new_v += divisor
    return new_v
class ConvBNReLU(nn.Sequential):
    # Conv2D -> BatchNorm -> ReLU6 block; passing groups equal to the input
    # channel count turns the convolution into a depthwise one.
    def __init__(self,
                 in_planes,
                 out_planes,
                 kernel_size=3,
                 stride=1,
                 groups=1,
                 norm_layer=nn.BatchNorm2D):
        padding = (kernel_size - 1) // 2

        super(ConvBNReLU, self).__init__(
            nn.Conv2D(
                in_planes,
                out_planes,
                kernel_size,
                stride,
                padding,
                groups=groups,
                bias_attr=False),
            norm_layer(out_planes),
            nn.ReLU6())
class InvertedResidual(nn.Layer):
    def __init__(self,
                 inp,
                 oup,
                 stride,
                 expand_ratio,
                 norm_layer=nn.BatchNorm2D):
        super(InvertedResidual, self).__init__()
        self.stride = stride
        assert stride in [1, 2]

        hidden_dim = int(round(inp * expand_ratio))
        self.use_res_connect = self.stride == 1 and inp == oup

        layers = []
        if expand_ratio != 1:
            # Pointwise expansion up to hidden_dim channels.
            layers.append(
                ConvBNReLU(
                    inp, hidden_dim, kernel_size=1, norm_layer=norm_layer))
        layers.extend([
            # Depthwise 3x3 convolution (groups == channels).
            ConvBNReLU(
                hidden_dim,
                hidden_dim,
                stride=stride,
                groups=hidden_dim,
                norm_layer=norm_layer),
            # Linear pointwise projection back down to oup channels.
            nn.Conv2D(
                hidden_dim, oup, 1, 1, 0, bias_attr=False),
            norm_layer(oup),
        ])
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        if self.use_res_connect:
            return x + self.conv(x)
        else:
            return self.conv(x)
class MobileNetV2(nn.Layer):
    def __init__(self, scale=1.0, num_classes=1000, with_pool=True):
        """MobileNetV2 model from
        `"MobileNetV2: Inverted Residuals and Linear Bottlenecks" <https://arxiv.org/abs/1801.04381>`_.

        Args:
            scale (float): scale of channels in each layer. Default: 1.0.
            num_classes (int): output dim of last fc layer. If num_classes <= 0,
                the last fc layer will not be defined. Default: 1000.
            with_pool (bool): use pool before the last fc layer or not. Default: True.

        Examples:
            .. code-block:: python

                from paddle.vision.models import MobileNetV2

                model = MobileNetV2()
        """
        super(MobileNetV2, self).__init__()
        self.num_classes = num_classes
        self.with_pool = with_pool
        input_channel = 32
        last_channel = 1280

        block = InvertedResidual
        round_nearest = 8
        norm_layer = nn.BatchNorm2D
        # Each row is (t, c, n, s): expansion factor, output channels,
        # number of repeated blocks, and the stride of the first block.
        inverted_residual_setting = [
            [1, 16, 1, 1],
            [6, 24, 2, 2],
            [6, 32, 3, 2],
            [6, 64, 4, 2],
            [6, 96, 3, 1],
            [6, 160, 3, 2],
            [6, 320, 1, 1],
        ]

        input_channel = _make_divisible(input_channel * scale, round_nearest)
        self.last_channel = _make_divisible(last_channel * max(1.0, scale),
                                            round_nearest)
        features = [
            ConvBNReLU(
                3, input_channel, stride=2, norm_layer=norm_layer)
        ]

        for t, c, n, s in inverted_residual_setting:
            output_channel = _make_divisible(c * scale, round_nearest)
            for i in range(n):
                stride = s if i == 0 else 1
                features.append(
                    block(
                        input_channel,
                        output_channel,
                        stride,
                        expand_ratio=t,
                        norm_layer=norm_layer))
                input_channel = output_channel

        features.append(
            ConvBNReLU(
                input_channel,
                self.last_channel,
                kernel_size=1,
                norm_layer=norm_layer))

        self.features = nn.Sequential(*features)

        if with_pool:
            self.pool2d_avg = nn.AdaptiveAvgPool2D(1)

        if self.num_classes > 0:
            self.classifier = nn.Sequential(
                nn.Dropout(0.2), nn.Linear(self.last_channel, num_classes))

    def forward(self, x):
        fea = self.features(x)

        if self.with_pool:
            x = self.pool2d_avg(fea)
        else:
            x = fea

        if self.num_classes > 0:
            x = paddle.flatten(x, 1)
            x = self.classifier(x)
        # Unlike the stock paddle.vision model, this variant also returns the
        # backbone feature map, which the SMILE regularizer operates on.
        return x, fea
def _mobilenet(arch, pretrained=False, **kwargs):
192+
model = MobileNetV2(**kwargs)
193+
if pretrained:
194+
assert arch in model_urls, "{} model do not have a pretrained model now, you should set pretrained=False".format(
195+
arch)
196+
weight_path = get_weights_path_from_url(model_urls[arch][0],
197+
model_urls[arch][1])
198+
199+
param = paddle.load(weight_path)
200+
model.load_dict(param)
201+
202+
return model
203+
204+
205+
def mobilenet_v2(pretrained=False, scale=1.0, **kwargs):
206+
"""MobileNetV2
207+
208+
Args:
209+
pretrained (bool): If True, returns a model pre-trained on ImageNet. Default: False.
210+
scale: (float): scale of channels in each layer. Default: 1.0.
211+
212+
Examples:
213+
.. code-block:: python
214+
215+
from paddle.vision.models import mobilenet_v2
216+
217+
# build model
218+
model = mobilenet_v2()
219+
220+
# build model and load imagenet pretrained weight
221+
# model = mobilenet_v2(pretrained=True)
222+
223+
# build mobilenet v2 with scale=0.5
224+
model = mobilenet_v2(scale=0.5)
225+
"""
226+
model = _mobilenet(
227+
'mobilenetv2_' + str(scale), pretrained, scale=scale, **kwargs)
228+
return model
