 from tqdm import trange
 
 from inclearn import factory, utils
+from inclearn.lib import callbacks
 from inclearn.models.base import IncrementalLearner
 
 
@@ -29,16 +30,14 @@ def __init__(self, args):
         self._k = args["memory_size"]
         self._n_classes = args["increment"]
 
-        self._temperature = args["temperature"]
+        self._temperature = 2.  # args["temperature"]
 
         self._features_extractor = factory.get_resnet(
             args["convnet"], nf=64, zero_init_residual=True
         )
-        self._classifier = nn.Linear(self._features_extractor.out_dim, self._n_classes, bias=False)
-        torch.nn.init.kaiming_normal_(self._classifier.weight)
+        self._classifier = nn.Linear(self._features_extractor.out_dim, self._n_classes, bias=True)
 
         self._examplars = {}
-        self._means = None
 
         self.to(self._device)
 
@@ -65,28 +64,48 @@ def _before_task(self, train_loader, val_loader):
         else:
             print("Computing previous predictions...")
             self._previous_preds = self._compute_predictions(train_loader)
-            if val_loader:
-                self._previous_preds_val = self._compute_predictions(val_loader)
 
         self._add_n_classes(self._task_size)
 
     def _train_task(self, train_loader, val_loader):
71+ """Train & fine-tune model.
72+
73+ The scheduling is different from the paper for one reason. In the paper,
74+ End-to-End Incremental Learning, the authors pre-generated 12 augmentations
75+ per images (thus multiplying by this number the dataset size). However
76+ I find this inefficient for large scale datasets, thus I'm simply doing
77+ the augmentations online. A greater number of epochs is then needed to
78+ match performances.
79+
80+ :param train_loader: A DataLoader.
81+ :param val_loader: A DataLoader, can be None.
82+ """
         # Training on all new + examplars
-        self.foo = 0
-        optimizer = factory.get_optimizer(self.parameters(), self._opt_name, 0.1, 0.0001)
-        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [10, 20, 30], gamma=0.1)
-        self._train(train_loader, 1, optimizer, scheduler)
+        self._best_acc = float("-inf")
+
+        print("Training")
+        self._finetuning = False
+        optimizer = factory.get_optimizer(self.parameters(), self._opt_name, 0.1, 0.001)
+        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [50, 60], gamma=0.2)
+        self._train(train_loader, val_loader, 70, optimizer, scheduler)
 
         if self._task == 0:
+            print("best", self._best_acc)
             return
 
         # Fine-tuning on sub-set new + examplars
-        self._build_examplars(train_loader)
+        print("Fine-tuning")
+        self._finetuning = True
+        self._build_examplars(train_loader,
+                              n_examplars=self._k // (self._n_classes - self._task_size))
         train_loader.dataset.set_idxes(self.examplars)  # Fine-tuning only on balanced dataset
-        optimizer = factory.get_optimizer(self.parameters(), self._opt_name, 0.01, 0.0001)
-        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [10, 20], gamma=0.1)
-        self.foo = 1
-        self._train(train_loader, 1, optimizer, scheduler)
+        self._previous_preds = self._compute_predictions(train_loader)
+
+        optimizer = factory.get_optimizer(self.parameters(), self._opt_name, 0.01, 0.001)
+        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [20, 40], gamma=0.2)
+        self._train(train_loader, val_loader, 50, optimizer, scheduler)
+
+        print("best", self._best_acc)
 
     def _after_task(self, data_loader):
         self._reduce_examplars()
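For reference, here is a minimal standalone sketch of the two-stage schedule introduced above. It assumes plain SGD (the real optimizer comes from `factory.get_optimizer` and depends on `self._opt_name`), and the model and loop bodies are placeholders rather than the actual training code:

```python
from torch import nn, optim

# Sketch only: a dummy model stands in for the incremental learner.
model = nn.Linear(10, 2)

# Stage 1 -- train on all new data + examplars (70 epochs).
optimizer = optim.SGD(model.parameters(), lr=0.1, weight_decay=0.001)
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[50, 60], gamma=0.2)
for epoch in range(70):
    # ... forward / backward over train_loader would go here ...
    optimizer.step()
    scheduler.step()  # lr: 0.1 -> 0.02 at epoch 50 -> 0.004 at epoch 60

# Stage 2 -- fine-tune on the balanced examplar set with a lower lr (50 epochs).
optimizer = optim.SGD(model.parameters(), lr=0.01, weight_decay=0.001)
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[20, 40], gamma=0.2)
for epoch in range(50):
    optimizer.step()
    scheduler.step()
```

`MultiStepLR` multiplies the learning rate by `gamma` at every milestone, which is what spreads the decay over the longer online-augmentation schedule described in the docstring.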
@@ -105,12 +124,17 @@ def get_memory_indexes(self):
     # Private API
     # -----------
 
-    def _train(self, train_loader, n_epochs, optimizer, scheduler):
+    def _train(self, train_loader, val_loader, n_epochs, optimizer, scheduler):
         print("nb ", len(train_loader.dataset))
-
         prog_bar = trange(n_epochs, desc="Losses.")
 
+        val_acc = 0.
         for epoch in prog_bar:
+            if epoch % 10 == 0 and val_loader:
+                ypred, ytrue = self._classify(val_loader)
+                val_acc = (ypred == ytrue).sum() / len(ytrue)
+                self._best_acc = max(self._best_acc, val_acc)
+
             _clf_loss, _distil_loss = 0., 0.
             c = 0
@@ -143,14 +167,17 @@ def _train(self, train_loader, n_epochs, optimizer, scheduler):
 
                 if i % 10 == 0 or i >= len(train_loader):
                     prog_bar.set_description(
-                        "Clf loss: {}; Distill loss: {}".format(
-                            round(clf_loss.item(), 3), round(distil_loss.item(), 3)
+                        "Clf loss: {}; Distill loss: {}; Val acc: {}".format(
+                            round(clf_loss.item(), 3), round(distil_loss.item(), 3),
+                            round(val_acc, 3)
                         )
                     )
 
+
             prog_bar.set_description(
-                "Clf loss: {}; Distill loss: {}".format(
-                    round(_clf_loss / c, 3), round(_distil_loss / c, 3)
+                "Clf loss: {}; Distill loss: {}; Val acc: {}".format(
+                    round(_clf_loss / c, 3), round(_distil_loss / c, 3),
+                    round(val_acc, 3)
                 )
             )
 
@@ -165,18 +192,17 @@ def _compute_loss(self, logits, targets, idxes):
         match the previous predictions.
         :return: A tuple of the classification loss and the distillation loss.
         """
+        clf_loss = F.cross_entropy(logits, targets)
+
         if self._task == 0:
-            clf_loss = F.cross_entropy(logits, targets)
             distil_loss = torch.zeros(1, device=self._device)
         else:
-            # Disable the cross_entropy loss for the old targets:
-            for i in range(self._new_task_index):
-                targets[targets == i] = -1
-            clf_loss = F.cross_entropy(logits, targets, ignore_index=-1)
+            if not self._finetuning:
+                logits = logits[..., :self._new_task_index]
 
             distil_loss = F.binary_cross_entropy(
-                F.softmax(logits[..., :self._new_task_index] ** (1 / self._temperature), dim=1),
-                F.softmax(self._previous_preds[idxes] ** (1 / self._temperature), dim=1)
+                F.softmax(logits / self._temperature, dim=1),
+                F.softmax(self._previous_preds[idxes] / self._temperature, dim=1)
             )
 
         return clf_loss, distil_loss
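As an illustration of the distillation term the new code computes, here is a toy snippet with made-up tensor shapes and the hard-coded temperature T = 2; it only shows the loss call, not the model's forward pass or the slicing to old classes:

```python
import torch
import torch.nn.functional as F

# Toy shapes: 4 samples, 5 old classes.
temperature = 2.
old_logits = torch.randn(4, 5)   # predictions stored before the new task
new_logits = torch.randn(4, 5)   # current logits restricted to the old classes

distil_loss = F.binary_cross_entropy(
    F.softmax(new_logits / temperature, dim=1),   # soft predictions (new model)
    F.softmax(old_logits / temperature, dim=1),   # soft targets (old model)
)
print(distil_loss.item())
```

Dividing the logits by the temperature before the softmax flattens both distributions, so the binary cross-entropy also penalizes mismatches on the non-argmax classes instead of only the top prediction.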
@@ -210,7 +236,7 @@ def _classify(self, loader):
         for _, inputs, targets in loader:
             inputs = inputs.to(self._device)
             logits = self.forward(inputs)
-            preds = F.softmax(logits, dim=1).argmax(dim=1)
+            preds = logits.argmax(dim=1).cpu().numpy()
 
             ypred.extend(preds)
             ytrue.extend(targets)
@@ -225,11 +251,9 @@ def _m(self):
     def _add_n_classes(self, n):
         self._n_classes += n
 
-        weights = self._classifier.weight.data
+        weights = self._classifier.weight.data.clone()
         self._classifier = nn.Linear(self._features_extractor.out_dim, self._n_classes,
-                                     bias=False).to(self._device)
-        torch.nn.init.kaiming_normal_(self._classifier.weight)
-
+                                     bias=True).to(self._device)
         self._classifier.weight.data[:self._n_classes - n] = weights
 
         print("Now {} examplars per class.".format(self._m))
@@ -276,31 +300,36 @@ def _dist(a, b):
276300 """
277301 return torch .pow (a - b , 2 ).sum (- 1 )
278302
279- def _build_examplars (self , loader ):
303+ def _build_examplars (self , loader , n_examplars = None ):
280304 """Builds new examplars.
281305
282306 :param loader: A DataLoader.
307+ :param n_examplars: Maximum number of examplars to create.
283308 """
309+ n_examplars = n_examplars or self ._m
310+
284311 lo , hi = self ._task * self ._task_size , self ._n_classes
285312 print ("Building examplars for classes {} -> {}." .format (lo , hi ))
286313 for class_idx in range (lo , hi ):
287314 loader .dataset .set_classes_range (class_idx , class_idx )
288- self ._examplars [class_idx ] = self ._build_class_examplars (loader )
315+ self ._examplars [class_idx ] = self ._build_class_examplars (loader , n_examplars )
289316
290- def _build_class_examplars (self , loader ):
317+ def _build_class_examplars (self , loader , n_examplars ):
291318 """Build examplars for a single class.
292319
293320 Examplars are selected as the closest to the class mean.
294321
295322 :param loader: DataLoader that provides images for a single class.
323+ :param n_examplars: Maximum number of examplars to create.
296324 :return: The real indexes of the chosen examplars.
297325 """
298326 features , class_mean , idxes = self ._extract_features (loader )
299327
300328 class_mean = F .normalize (class_mean , dim = 0 )
329+ features = F .normalize (features , dim = 1 )
301330 distances_to_mean = self ._dist (class_mean , features )
302331
303- nb_examplars = min (self . _m , len (features ))
332+ nb_examplars = min (n_examplars , len (features ))
304333
305334 fake_idxes = distances_to_mean .argsort ().cpu ().numpy ()[:nb_examplars ]
306335 return [idxes [idx ] for idx in fake_idxes ]
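A self-contained illustration of the selection rule in `_build_class_examplars`, with random features standing in for the output of `_extract_features`: keep the samples whose normalized features are closest to the normalized class mean.

```python
import torch
import torch.nn.functional as F

# Random stand-in for the features of one class (one row per image).
features = torch.randn(100, 64)
class_mean = F.normalize(features.mean(dim=0), dim=0)
features = F.normalize(features, dim=1)

# Same metric as _dist: squared Euclidean distance to the class mean.
distances = torch.pow(class_mean - features, 2).sum(-1)

n_examplars = 10
chosen = distances.argsort()[:n_examplars]   # indexes of the closest samples
print(chosen.tolist())
```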