 from ..distances import CosineSimilarity
 from ..reducers import AvgNonZeroReducer
 from ..utils import common_functions as c_f
-from ..utils import multilabel_loss_and_miner_utils as mlmu
 from ..utils import loss_and_miner_utils as lmu
+from ..utils.module_with_records import ModuleWithRecords
 from .generic_pair_loss import GenericPairLoss
-
+from .base_loss_wrapper import BaseLossWrapper
 
 # adapted from https://github.com/HobbitLong/SupContrast
+# modified for multi-supcon
 class MultiSupConLoss(GenericPairLoss):
-    def __init__(self, num_classes, temperature=0.1, **kwargs):
+    def __init__(self, num_classes, temperature=0.1, threshold=0.3, **kwargs):
         super().__init__(mat_based_loss=True, **kwargs)
         self.temperature = temperature
         self.add_to_recordable_attributes(list_of_names=["temperature"], is_stat=False)
         self.num_classes = num_classes
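+        # pairs whose label jaccard similarity exceeds `threshold` are treated
+        # as positives (see get_matches_and_diffs below)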
+        self.threshold = threshold
 
-    def _compute_loss(self, mat, pos_mask, neg_mask):
+    def _compute_loss(self, mat, pos_mask, neg_mask, multi_val):
         if pos_mask.bool().any() and neg_mask.bool().any():
             # if dealing with actual distances, use negative distances
             if not self.distance.is_inverted:
@@ -29,7 +31,7 @@ def _compute_loss(self, mat, pos_mask, neg_mask):
                 mat, keep_mask=(pos_mask + neg_mask).bool(), add_one=False, dim=1
             )
             log_prob = mat - denominator
-            mean_log_prob_pos = (pos_mask * log_prob).sum(dim=1) / (
+            mean_log_prob_pos = (multi_val * log_prob * pos_mask).sum(dim=1) / (
                 pos_mask.sum(dim=1) + c_f.small_val(mat.dtype)
             )
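+            # unlike single-label SupCon, each positive's log-probability is
+            # weighted by its jaccard similarity (multi_val) before averaging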
 
@@ -48,16 +50,22 @@ def get_default_reducer(self):
     def get_default_distance(self):
         return CosineSimilarity()
 
+    # ==== class methods below are overridden for adaptability to multi-supcon ====
+
     def mat_based_loss(self, mat, indices_tuple):
-        a1, p, a2, n = indices_tuple
+        a1, p, a2, n, jaccard_mat = indices_tuple
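+        # multi-label tuples carry the pairwise jaccard matrix as a 5th element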
         pos_mask, neg_mask = torch.zeros_like(mat), torch.zeros_like(mat)
         pos_mask[a1, p] = 1
         neg_mask[a2, n] = 1
-        return self._compute_loss(mat, pos_mask, neg_mask)
+        return self._compute_loss(mat, pos_mask, neg_mask, jaccard_mat)
 
     def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels):
         c_f.labels_or_indices_tuple_required(labels, indices_tuple)
-        indices_tuple = mlmu.convert_to_pairs(indices_tuple, labels, self.num_classes, ref_labels, device=embeddings.device)
+        indices_tuple = convert_to_pairs(
+            indices_tuple,
+            labels,
+            ref_labels,
+            threshold=self.threshold,
+        )
         if all(len(x) <= 1 for x in indices_tuple):
             return self.zero_losses()
         mat = self.distance(embeddings, ref_emb)
@@ -76,11 +84,276 @@ def forward(
         Returns: the loss
         """
         self.reset_stats()
-        mlmu.check_shapes_multilabels(embeddings, labels)
-        ref_emb, ref_labels = mlmu.set_ref_emb(embeddings, labels, ref_emb, ref_labels)
+        check_shapes_multilabels(embeddings, labels)
+        ref_emb, ref_labels = set_ref_emb(embeddings, labels, ref_emb, ref_labels)
         loss_dict = self.compute_loss(
             embeddings, labels, indices_tuple, ref_emb, ref_labels
         )
         self.add_embedding_regularization_to_loss_dict(loss_dict, embeddings)
         return self.reducer(loss_dict, embeddings, labels)
 
+    # =========================================================================
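+
+# Usage sketch (illustrative only): labels are multi-hot vectors of shape
+# (batch_size, num_classes), unlike the 1D integer labels of SupConLoss.
+#
+#     loss_fn = MultiSupConLoss(num_classes=10, temperature=0.1, threshold=0.3)
+#     embeddings = torch.randn(32, 128)        # (batch_size, embedding_dim)
+#     labels = torch.randint(0, 2, (32, 10))   # multi-hot label matrix
+#     loss = loss_fn(embeddings, labels)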
+
+
+# ================== cross batch memory for multi-supcon ==================
+class CrossBatchMemory4MultiLabel(BaseLossWrapper, ModuleWithRecords):
+    def __init__(self, loss, embedding_size, memory_size=1024, miner=None, **kwargs):
+        super().__init__(loss=loss, **kwargs)
+        self.loss = loss
+        self.miner = miner
+        self.embedding_size = embedding_size
+        self.memory_size = memory_size
+        self.num_classes = loss.num_classes
+        self.reset_queue()
+        self.add_to_recordable_attributes(
+            list_of_names=["embedding_size", "memory_size", "queue_idx"], is_stat=False
+        )
+
+    @staticmethod
+    def supported_losses():
+        return [
+            "MultiSupConLoss"
+        ]
+
+    @classmethod
+    def check_loss_support(cls, loss_name):
+        if loss_name not in cls.supported_losses():
+            raise Exception(f"CrossBatchMemory not supported for {loss_name}")
+
+    def forward(self, embeddings, labels, indices_tuple=None, enqueue_mask=None):
+        if indices_tuple is not None and enqueue_mask is not None:
+            raise ValueError("indices_tuple and enqueue_mask are mutually exclusive")
+        if enqueue_mask is not None:
+            assert len(enqueue_mask) == len(embeddings)
+        else:
+            assert len(embeddings) <= len(self.embedding_memory)
+        self.reset_stats()
+        device = embeddings.device
+        labels = c_f.to_device(labels, device=device)
+        self.embedding_memory = c_f.to_device(
+            self.embedding_memory, device=device, dtype=embeddings.dtype
+        )
+        self.label_memory = c_f.to_device(
+            self.label_memory, device=device, dtype=labels.dtype
+        )
+
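+        # rows selected by enqueue_mask are stored in the queue; the remaining
+        # rows act only as queries, so their self-comparisons need no removal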
+        if enqueue_mask is not None:
+            emb_for_queue = embeddings[enqueue_mask]
+            labels_for_queue = labels[enqueue_mask]
+            embeddings = embeddings[~enqueue_mask]
+            labels = labels[~enqueue_mask]
+            do_remove_self_comparisons = False
+        else:
+            emb_for_queue = embeddings
+            labels_for_queue = labels
+            do_remove_self_comparisons = True
+
+        queue_batch_size = len(emb_for_queue)
+        self.add_to_memory(emb_for_queue, labels_for_queue, queue_batch_size)
+
+        if not self.has_been_filled:
+            E_mem = self.embedding_memory[: self.queue_idx]
+            L_mem = self.label_memory[: self.queue_idx]
+        else:
+            E_mem = self.embedding_memory
+            L_mem = self.label_memory
+
+        indices_tuple = self.create_indices_tuple(
+            embeddings,
+            labels,
+            E_mem,
+            L_mem,
+            indices_tuple,
+            do_remove_self_comparisons,
+        )
+        loss = self.loss(embeddings, labels, indices_tuple, E_mem, L_mem)
+        return loss
+
+    def add_to_memory(self, embeddings, labels, batch_size):
+        self.curr_batch_idx = (
+            torch.arange(
+                self.queue_idx, self.queue_idx + batch_size, device=labels.device
+            )
+            % self.memory_size
+        )
+        self.embedding_memory[self.curr_batch_idx] = embeddings.detach()
+        self.label_memory[self.curr_batch_idx] = labels.detach()
+        prev_queue_idx = self.queue_idx
+        self.queue_idx = (self.queue_idx + batch_size) % self.memory_size
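+        # if queue_idx wrapped around, the ring buffer has now been written
+        # end-to-end at least once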
+        if (not self.has_been_filled) and (self.queue_idx <= prev_queue_idx):
+            self.has_been_filled = True
+
+    def create_indices_tuple(
+        self,
+        embeddings,
+        labels,
+        E_mem,
+        L_mem,
+        input_indices_tuple,
+        do_remove_self_comparisons,
+    ):
+        if self.miner:
+            indices_tuple = self.miner(embeddings, labels, E_mem, L_mem)
+        else:
+            indices_tuple = get_all_pairs_indices(labels, L_mem)
+
+        if do_remove_self_comparisons:
+            indices_tuple = remove_self_comparisons(
+                indices_tuple, self.curr_batch_idx, self.memory_size
+            )
+
+        if input_indices_tuple is not None:
+            # multi-label pair tuples have 5 elements (the jaccard matrix is
+            # appended), so compare against 5 rather than the usual 4
+            if len(input_indices_tuple) == 3 and len(indices_tuple) == 5:
+                input_indices_tuple = convert_to_pairs(input_indices_tuple, labels)
+            elif len(input_indices_tuple) == 5 and len(indices_tuple) == 3:
+                input_indices_tuple = convert_to_triplets(
+                    input_indices_tuple, labels
+                )
+            indices_tuple = c_f.concatenate_indices_tuples(
+                indices_tuple, input_indices_tuple
+            )
+
+        return indices_tuple
+
+    def reset_queue(self):
+        self.register_buffer(
+            "embedding_memory", torch.zeros(self.memory_size, self.embedding_size)
+        )
+        self.register_buffer(
+            "label_memory", torch.zeros(self.memory_size, self.num_classes)
+        )
+        self.has_been_filled = False
+        self.queue_idx = 0
+
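+# Usage sketch (illustrative only): wrap the loss so each batch is also
+# contrasted against a queue of past embeddings.
+#
+#     loss_fn = MultiSupConLoss(num_classes=10)
+#     xbm = CrossBatchMemory4MultiLabel(loss_fn, embedding_size=128)
+#     loss = xbm(embeddings, labels)  # embeddings/labels as in the sketch above
+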
+# =========================================================================
+
+# compute jaccard similarity
+def jaccard(labels, ref_labels=None):
+    if ref_labels is None:
+        ref_labels = labels
+
+    labels1 = labels.float()
+    labels2 = ref_labels.float()
+
+    # jaccard = intersection / union for every (labels1, labels2) row pair:
+    # |A ∩ B| via dot product of multi-hot rows; |A ∪ B| = |A| + |B| - |A ∩ B|
+    labels1_count = labels1.sum(-1)
+    labels2_count = labels2.sum(-1)
+    count_sum = labels1_count.unsqueeze(1) + labels2_count.unsqueeze(0)
+    intersection = torch.mm(labels1, labels2.T)
+    jaccard_matrix = intersection / (count_sum - intersection)

+    # return the full (len(labels), len(ref_labels)) jaccard similarity matrix
+    return jaccard_matrix
+
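+# Worked example (sketch): for labels = [[1, 1, 0], [0, 1, 1]], rows 0 and 1
+# intersect in one class and their union has three classes, so
+# jaccard(labels)[0, 1] == 1/3.
+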
+# ====== methods below are overridden for adaptability to multi-supcon ======
+
+# use jaccard similarity to get matches
+def get_matches_and_diffs(labels, ref_labels=None, threshold=0.3):
+    if ref_labels is None:
+        ref_labels = labels
+    jaccard_matrix = jaccard(labels, ref_labels)
+    matches = torch.where(jaccard_matrix > threshold, 1, 0)
+    diffs = matches ^ 1
+    if ref_labels is labels:
+        matches.fill_diagonal_(0)
+    return matches, diffs, jaccard_matrix
+
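+# With the default threshold of 0.3, the pair above (jaccard 1/3 > 0.3) counts
+# as a positive match; diagonal self-matches are zeroed out.
+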
+def check_shapes_multilabels(embeddings, labels):
+    if labels is not None and embeddings.shape[0] != labels.shape[0]:
+        raise ValueError("Number of embeddings must equal number of labels")
+    if labels is not None and labels.ndim != 2:
+        raise ValueError("labels must be a 2D tensor of shape (batch_size, num_classes)")
+
+
+def set_ref_emb(embeddings, labels, ref_emb, ref_labels):
+    if ref_emb is None:
+        ref_emb, ref_labels = embeddings, labels
+    check_shapes_multilabels(ref_emb, ref_labels)
+    return ref_emb, ref_labels
+
+
+def convert_to_pairs(indices_tuple, labels, ref_labels=None, threshold=0.3):
+    """
+    This returns anchor-positive and anchor-negative indices plus the jaccard
+    similarity matrix, regardless of what the input indices_tuple is.
+    Args:
+        indices_tuple: tuple of tensors. Each tensor is 1d and specifies indices
+            within a batch
+        labels: a tensor which has the label for each element in a batch
+    """
+    if indices_tuple is None:
+        return get_all_pairs_indices(labels, ref_labels, threshold=threshold)
+    elif len(indices_tuple) == 5:
+        return indices_tuple
+    else:
+        a, p, n, jaccard_mat = indices_tuple
+        return a, p, a, n, jaccard_mat
+
+
+def get_all_pairs_indices(labels, ref_labels=None, threshold=0.3):
+    """
+    Given a tensor of multi-hot labels, this will return 5 tensors.
+    The first 2 tensors are the indices which form all positive pairs.
+    The second 2 tensors are the indices which form all negative pairs.
+    The last tensor is the jaccard similarity matrix used as pair weights.
+    """
+    matches, diffs, multi_val = get_matches_and_diffs(
+        labels, ref_labels, threshold=threshold
+    )
+    a1_idx, p_idx = torch.where(matches)
+    a2_idx, n_idx = torch.where(diffs)
+    return a1_idx, p_idx, a2_idx, n_idx, multi_val
+
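+# Example (sketch): for the two-sample batch above, a1_idx = [0, 1],
+# p_idx = [1, 0], the negative index tensors are empty, and multi_val is the
+# 2x2 jaccard matrix.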
+
+def convert_to_triplets(indices_tuple, labels, ref_labels=None, t_per_anchor=100):
+    """
+    This returns anchor-positive-negative triplets
+    regardless of what the input indices_tuple is
+    """
+    if indices_tuple is None:
+        if t_per_anchor == "all":
+            return get_all_triplets_indices(labels, ref_labels)
+        else:
+            return lmu.get_random_triplet_indices(
+                labels, ref_labels, t_per_anchor=t_per_anchor
+            )
+    elif len(indices_tuple) == 3:
+        return indices_tuple
+    else:
+        a1, p, a2, n = indices_tuple[:4]  # drop the jaccard matrix if present
+        p_idx, n_idx = torch.where(a1.unsqueeze(1) == a2)
+        return a1[p_idx], p[p_idx], n[n_idx]
+
+
+def get_all_triplets_indices(labels, ref_labels=None):
+    # get_matches_and_diffs also returns the jaccard matrix, which plain
+    # triplets do not need
+    matches, diffs, _ = get_matches_and_diffs(labels, ref_labels)
+    triplets = matches.unsqueeze(2) * diffs.unsqueeze(1)
+    return torch.where(triplets)
+
+
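+# When a batch has just been enqueued, each of its samples would otherwise be
+# paired with its own copy in the memory; the helper below drops those
+# degenerate anchor-positive pairs.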
+def remove_self_comparisons(
+    indices_tuple, curr_batch_idx, ref_size, ref_is_subset=False
+):
+    # remove self-comparisons
+    assert len(indices_tuple) in [4, 5]
+    s, e = curr_batch_idx[0], curr_batch_idx[-1]
+    if len(indices_tuple) == 4:
+        a, p, n, jaccard_mat = indices_tuple
+        keep_mask = lmu.not_self_comparisons(
+            a, p, s, e, curr_batch_idx, ref_size, ref_is_subset
+        )
+        a = a[keep_mask]
+        p = p[keep_mask]
+        n = n[keep_mask]
+        assert len(a) == len(p) == len(n)
+        return a, p, n, jaccard_mat
+    elif len(indices_tuple) == 5:
+        a1, p, a2, n, jaccard_mat = indices_tuple
+        keep_mask = lmu.not_self_comparisons(
+            a1, p, s, e, curr_batch_idx, ref_size, ref_is_subset
+        )
+        a1 = a1[keep_mask]
+        p = p[keep_mask]
+        assert len(a1) == len(p)
+        assert len(a2) == len(n)
+        return a1, p, a2, n, jaccard_mat
+
+# =========================================================================