Commit f7b344b

Author: Arthur Douillard (committed)
[e2e] Add final perf of e2e (hard-coded params TOFIX).
1 parent 6e52b49 commit f7b344b

File tree: 1 file changed, +66 -37 lines

inclearn/models/e2e.py

Lines changed: 66 additions & 37 deletions
@@ -5,6 +5,7 @@
 from tqdm import trange
 
 from inclearn import factory, utils
+from inclearn.lib import callbacks
 from inclearn.models.base import IncrementalLearner
 
 
@@ -29,16 +30,14 @@ def __init__(self, args):
         self._k = args["memory_size"]
         self._n_classes = args["increment"]
 
-        self._temperature = args["temperature"]
+        self._temperature = 2.  # args["temperature"]
 
         self._features_extractor = factory.get_resnet(
             args["convnet"], nf=64, zero_init_residual=True
         )
-        self._classifier = nn.Linear(self._features_extractor.out_dim, self._n_classes, bias=False)
-        torch.nn.init.kaiming_normal_(self._classifier.weight)
+        self._classifier = nn.Linear(self._features_extractor.out_dim, self._n_classes, bias=True)
 
         self._examplars = {}
-        self._means = None
 
         self.to(self._device)
 
@@ -65,28 +64,48 @@ def _before_task(self, train_loader, val_loader):
         else:
             print("Computing previous predictions...")
             self._previous_preds = self._compute_predictions(train_loader)
-            if val_loader:
-                self._previous_preds_val = self._compute_predictions(val_loader)
 
         self._add_n_classes(self._task_size)
 
     def _train_task(self, train_loader, val_loader):
+        """Train & fine-tune model.
+
+        The scheduling is different from the paper for one reason. In the paper,
+        End-to-End Incremental Learning, the authors pre-generate 12 augmentations
+        per image (thus multiplying the dataset size by this number). However,
+        I find this inefficient for large-scale datasets, so I simply apply
+        the augmentations online. A greater number of epochs is then needed to
+        match their performance.
+        :param train_loader: A DataLoader.
+        :param val_loader: A DataLoader, can be None.
+        """
+
         # Training on all new + examplars
-        self.foo = 0
-        optimizer = factory.get_optimizer(self.parameters(), self._opt_name, 0.1, 0.0001)
-        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [10, 20, 30], gamma=0.1)
-        self._train(train_loader, 1, optimizer, scheduler)
+        self._best_acc = float("-inf")
+
+        print("Training")
+        self._finetuning = False
+        optimizer = factory.get_optimizer(self.parameters(), self._opt_name, 0.1, 0.001)
+        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [50, 60], gamma=0.2)
+        self._train(train_loader, val_loader, 70, optimizer, scheduler)
 
         if self._task == 0:
+            print("best", self._best_acc)
             return
 
         # Fine-tuning on sub-set new + examplars
-        self._build_examplars(train_loader)
+        print("Fine-tuning")
+        self._finetuning = True
+        self._build_examplars(train_loader,
+                              n_examplars=self._k // (self._n_classes - self._task_size))
         train_loader.dataset.set_idxes(self.examplars)  # Fine-tuning only on balanced dataset
-        optimizer = factory.get_optimizer(self.parameters(), self._opt_name, 0.01, 0.0001)
-        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [10, 20], gamma=0.1)
-        self.foo = 1
-        self._train(train_loader, 1, optimizer, scheduler)
+        self._previous_preds = self._compute_predictions(train_loader)
+
+        optimizer = factory.get_optimizer(self.parameters(), self._opt_name, 0.01, 0.001)
+        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [20, 40], gamma=0.2)
+        self._train(train_loader, val_loader, 50, optimizer, scheduler)
+
+        print("best", self._best_acc)
 
     def _after_task(self, data_loader):
         self._reduce_examplars()
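
The hunk above is where the hard-coded parameters from the commit message live: 70 epochs on all new data plus examplars with LR 0.1 decayed by 0.2 at epochs 50 and 60, then 50 epochs of fine-tuning on the balanced examplar subset with LR 0.01 decayed at epochs 20 and 40. Below is a minimal sketch of that two-stage schedule with plain SGD and MultiStepLR; it bypasses factory.get_optimizer (whose signature is not shown in this diff), and treating the trailing 0.001 as weight decay, as well as the momentum value, is an assumption.

import torch
from torch import nn

model = nn.Linear(64, 10)  # placeholder for the convnet + classifier

def run(optimizer, scheduler, n_epochs):
    for epoch in range(n_epochs):
        # ... one epoch over the current DataLoader (elided in this sketch) ...
        scheduler.step()

# Stage 1: training on all new samples + examplars (70 epochs, milestones [50, 60]).
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, weight_decay=0.001, momentum=0.9)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[50, 60], gamma=0.2)
run(optimizer, scheduler, 70)

# Stage 2: fine-tuning on the balanced examplar-only subset (50 epochs, milestones [20, 40]).
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, weight_decay=0.001, momentum=0.9)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[20, 40], gamma=0.2)
run(optimizer, scheduler, 50)
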
@@ -105,12 +124,17 @@ def get_memory_indexes(self):
     # Private API
     # -----------
 
-    def _train(self, train_loader, n_epochs, optimizer, scheduler):
+    def _train(self, train_loader, val_loader, n_epochs, optimizer, scheduler):
         print("nb ", len(train_loader.dataset))
-
         prog_bar = trange(n_epochs, desc="Losses.")
 
+        val_acc = 0.
         for epoch in prog_bar:
+            if epoch % 10 == 0 and val_loader:
+                ypred, ytrue = self._classify(val_loader)
+                val_acc = (ypred == ytrue).sum() / len(ytrue)
+                self._best_acc = max(self._best_acc, val_acc)
+
             _clf_loss, _distil_loss = 0., 0.
             c = 0
 
@@ -143,14 +167,17 @@ def _train(self, train_loader, n_epochs, optimizer, scheduler):
 
                 if i % 10 == 0 or i >= len(train_loader):
                     prog_bar.set_description(
-                        "Clf loss: {}; Distill loss: {}".format(
-                            round(clf_loss.item(), 3), round(distil_loss.item(), 3)
+                        "Clf loss: {}; Distill loss: {}; Val acc: {}".format(
+                            round(clf_loss.item(), 3), round(distil_loss.item(), 3),
+                            round(val_acc, 3)
                         )
                     )
 
+
             prog_bar.set_description(
-                "Clf loss: {}; Distill loss: {}".format(
-                    round(_clf_loss / c, 3), round(_distil_loss / c, 3)
+                "Clf loss: {}; Distill loss: {}; Val acc: {}".format(
+                    round(_clf_loss / c, 3), round(_distil_loss / c, 3),
+                    round(val_acc, 3)
                 )
             )
 
@@ -165,18 +192,17 @@ def _compute_loss(self, logits, targets, idxes):
         match the previous predictions.
         :return: A tuple of the classification loss and the distillation loss.
         """
+        clf_loss = F.cross_entropy(logits, targets)
+
         if self._task == 0:
-            clf_loss = F.cross_entropy(logits, targets)
             distil_loss = torch.zeros(1, device=self._device)
         else:
-            # Disable the cross_entropy loss for the old targets:
-            for i in range(self._new_task_index):
-                targets[targets == i] = -1
-            clf_loss = F.cross_entropy(logits, targets, ignore_index=-1)
+            if not self._finetuning:
+                logits = logits[..., :self._new_task_index]
 
             distil_loss = F.binary_cross_entropy(
-                F.softmax(logits[..., :self._new_task_index] ** (1 / self._temperature), dim=1),
-                F.softmax(self._previous_preds[idxes]**(1 / self._temperature), dim=1)
+                F.softmax(logits / self._temperature, dim=1),
+                F.softmax(self._previous_preds[idxes] / self._temperature, dim=1)
             )
 
         return clf_loss, distil_loss
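
After this change the loss has a plain cross-entropy term over all targets plus, for every task after the first, the distillation term of End-to-End Incremental Learning: logits restricted to the old classes (except during fine-tuning) are compared against the previous model's stored predictions, both softened by the temperature hard-coded to 2 above. A self-contained sketch of that distillation term, with illustrative shapes and names:

import torch
import torch.nn.functional as F

temperature = 2.0        # hard-coded in this commit
new_task_index = 5       # number of classes learned before the current task (illustrative)

logits = torch.randn(8, 10)           # current model outputs: 8 samples, 10 classes
previous_preds = torch.randn(8, 5)    # stored outputs of the previous model (old classes)

# Keep only old-class logits (skipped during fine-tuning in the diff), soften both
# distributions with the temperature, then compare them with binary cross-entropy.
old_logits = logits[..., :new_task_index]
distil_loss = F.binary_cross_entropy(
    F.softmax(old_logits / temperature, dim=1),
    F.softmax(previous_preds / temperature, dim=1),
)
print(distil_loss.item())
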
@@ -210,7 +236,7 @@ def _classify(self, loader):
         for _, inputs, targets in loader:
             inputs = inputs.to(self._device)
             logits = self.forward(inputs)
-            preds = F.softmax(logits, dim=1).argmax(dim=1)
+            preds = logits.argmax(dim=1).cpu().numpy()
 
             ypred.extend(preds)
             ytrue.extend(targets)
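
_classify now returns NumPy predictions (an argmax over raw logits is enough, since softmax is monotonic), which is what lets the training loop above compute (ypred == ytrue).sum() / len(ytrue). A small sketch of that evaluation path, assuming a loader that yields (index, inputs, targets) triples as in this repository:

import numpy as np
import torch

def classify(model, loader, device="cpu"):
    # Collect argmax predictions and ground-truth labels as NumPy arrays.
    ypred, ytrue = [], []
    for _, inputs, targets in loader:
        logits = model(inputs.to(device))
        ypred.extend(logits.argmax(dim=1).cpu().numpy())
        ytrue.extend(np.asarray(targets))
    return np.array(ypred), np.array(ytrue)

# Accuracy as computed every 10 epochs in _train:
# ypred, ytrue = classify(model, val_loader)
# val_acc = (ypred == ytrue).sum() / len(ytrue)
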
@@ -225,11 +251,9 @@ def _m(self):
     def _add_n_classes(self, n):
         self._n_classes += n
 
-        weights = self._classifier.weight.data
+        weights = self._classifier.weight.data.clone()
         self._classifier = nn.Linear(self._features_extractor.out_dim, self._n_classes,
-                                     bias=False).to(self._device)
-        torch.nn.init.kaiming_normal_(self._classifier.weight)
-
+                                     bias=True).to(self._device)
         self._classifier.weight.data[:self._n_classes - n] = weights
 
         print("Now {} examplars per class.".format(self._m))
@@ -276,31 +300,36 @@ def _dist(a, b):
         """
         return torch.pow(a - b, 2).sum(-1)
 
-    def _build_examplars(self, loader):
+    def _build_examplars(self, loader, n_examplars=None):
         """Builds new examplars.
 
         :param loader: A DataLoader.
+        :param n_examplars: Maximum number of examplars to create.
         """
+        n_examplars = n_examplars or self._m
+
         lo, hi = self._task * self._task_size, self._n_classes
         print("Building examplars for classes {} -> {}.".format(lo, hi))
         for class_idx in range(lo, hi):
             loader.dataset.set_classes_range(class_idx, class_idx)
-            self._examplars[class_idx] = self._build_class_examplars(loader)
+            self._examplars[class_idx] = self._build_class_examplars(loader, n_examplars)
 
-    def _build_class_examplars(self, loader):
+    def _build_class_examplars(self, loader, n_examplars):
         """Build examplars for a single class.
 
         Examplars are selected as the closest to the class mean.
 
         :param loader: DataLoader that provides images for a single class.
+        :param n_examplars: Maximum number of examplars to create.
         :return: The real indexes of the chosen examplars.
         """
         features, class_mean, idxes = self._extract_features(loader)
 
         class_mean = F.normalize(class_mean, dim=0)
+        features = F.normalize(features, dim=1)
         distances_to_mean = self._dist(class_mean, features)
 
-        nb_examplars = min(self._m, len(features))
+        nb_examplars = min(n_examplars, len(features))
 
         fake_idxes = distances_to_mean.argsort().cpu().numpy()[:nb_examplars]
         return [idxes[idx] for idx in fake_idxes]
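
Examplar selection now L2-normalizes both the class mean and every feature vector before keeping the n_examplars samples closest to the mean under the squared Euclidean distance of _dist. A standalone sketch of that selection, with random features standing in for the output of _extract_features:

import torch
import torch.nn.functional as F

features = torch.randn(100, 64)   # one feature vector per image of the class (illustrative)
idxes = list(range(100))          # real dataset indexes of those images
n_examplars = 20

class_mean = F.normalize(features.mean(dim=0), dim=0)
features = F.normalize(features, dim=1)
distances_to_mean = torch.pow(features - class_mean, 2).sum(-1)

nb_examplars = min(n_examplars, len(features))
fake_idxes = distances_to_mean.argsort().cpu().numpy()[:nb_examplars]
examplars = [idxes[idx] for idx in fake_idxes]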
