Commit 4c6f2ca

[feat] multi-supcon
1 parent ac60700 commit 4c6f2ca

5 files changed: +361 -0 lines changed

src/pytorch_metric_learning/losses/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -35,3 +35,5 @@
 from .triplet_margin_loss import TripletMarginLoss
 from .tuplet_margin_loss import TupletMarginLoss
 from .vicreg_loss import VICRegLoss
+from .multilabel_supcon_loss import MultiSupConLoss
+from .xbm_multilabel import CrossBatchMemory4MultiLabel
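With these two lines in place, both new classes become importable from the losses subpackage; a one-line sketch (assuming the package is installed from this commit):

from pytorch_metric_learning.losses import MultiSupConLoss, CrossBatchMemory4MultiLabel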
Lines changed: 86 additions & 0 deletions

@@ -0,0 +1,86 @@
import torch

from ..distances import CosineSimilarity
from ..reducers import AvgNonZeroReducer
from ..utils import common_functions as c_f
from ..utils import multilabel_loss_and_miner_utils as mlmu
from ..utils import loss_and_miner_utils as lmu
from .generic_pair_loss import GenericPairLoss


# adapted from https://github.com/HobbitLong/SupContrast
class MultiSupConLoss(GenericPairLoss):
    def __init__(self, num_classes, temperature=0.1, **kwargs):
        super().__init__(mat_based_loss=True, **kwargs)
        self.temperature = temperature
        self.add_to_recordable_attributes(list_of_names=["temperature"], is_stat=False)
        self.num_classes = num_classes

    def _compute_loss(self, mat, pos_mask, neg_mask):
        if pos_mask.bool().any() and neg_mask.bool().any():
            # if dealing with actual distances, use negative distances
            if not self.distance.is_inverted:
                mat = -mat
            mat = mat / self.temperature
            mat_max, _ = mat.max(dim=1, keepdim=True)
            mat = mat - mat_max.detach()  # for numerical stability

            denominator = lmu.logsumexp(
                mat, keep_mask=(pos_mask + neg_mask).bool(), add_one=False, dim=1
            )
            log_prob = mat - denominator
            mean_log_prob_pos = (pos_mask * log_prob).sum(dim=1) / (
                pos_mask.sum(dim=1) + c_f.small_val(mat.dtype)
            )

            return {
                "loss": {
                    "losses": -mean_log_prob_pos,
                    "indices": c_f.torch_arange_from_size(mat),
                    "reduction_type": "element",
                }
            }
        return self.zero_losses()

    def get_default_reducer(self):
        return AvgNonZeroReducer()

    def get_default_distance(self):
        return CosineSimilarity()

    def mat_based_loss(self, mat, indices_tuple):
        a1, p, a2, n = indices_tuple
        pos_mask, neg_mask = torch.zeros_like(mat), torch.zeros_like(mat)
        pos_mask[a1, p] = 1
        neg_mask[a2, n] = 1
        return self._compute_loss(mat, pos_mask, neg_mask)

    def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels):
        c_f.labels_or_indices_tuple_required(labels, indices_tuple)
        indices_tuple = mlmu.convert_to_pairs(
            indices_tuple, labels, self.num_classes, ref_labels, device=embeddings.device
        )
        if all(len(x) <= 1 for x in indices_tuple):
            return self.zero_losses()
        mat = self.distance(embeddings, ref_emb)
        return self.loss_method(mat, indices_tuple)

    def forward(
        self, embeddings, labels=None, indices_tuple=None, ref_emb=None, ref_labels=None
    ):
        """
        Args:
            embeddings: tensor of size (batch_size, embedding_size)
            labels: multilabels; a list of length batch_size whose elements are
                    lists (or 1d tensors) of class indices
            indices_tuple: tuple of size 3 for triplets (anchors, positives, negatives)
                            or size 4 for pairs (anchor1, positives, anchor2, negatives)
                            Can also be left as None
        Returns: the loss
        """
        self.reset_stats()
        mlmu.check_shapes_multilabels(embeddings, labels)
        ref_emb, ref_labels = mlmu.set_ref_emb(embeddings, labels, ref_emb, ref_labels)
        loss_dict = self.compute_loss(
            embeddings, labels, indices_tuple, ref_emb, ref_labels
        )
        self.add_embedding_regularization_to_loss_dict(loss_dict, embeddings)
        return self.reducer(loss_dict, embeddings, labels)
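For context, a minimal usage sketch of MultiSupConLoss (not part of this commit; it assumes the package is installed with these changes, and the sizes are toy values). Labels are passed as one list of class indices per sample, and positive pairs come from Jaccard overlap between label sets:

import torch
from pytorch_metric_learning.losses import MultiSupConLoss

embeddings = torch.randn(4, 128)
labels = [[0, 2], [2], [1, 3], [0]]  # one multilabel list per embedding

loss_func = MultiSupConLoss(num_classes=4, temperature=0.1)
loss = loss_func(embeddings, labels)  # scalar loss, averaged over anchors with nonzero loss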
Lines changed: 132 additions & 0 deletions

@@ -0,0 +1,132 @@
import torch

from ..utils import common_functions as c_f
# replace the functions of loss_and_miner_utils with multisupcon's
from ..utils import multilabel_loss_and_miner_utils as mlmu
from ..utils import loss_and_miner_utils as lmu
from ..utils.module_with_records import ModuleWithRecords
from .base_loss_wrapper import BaseLossWrapper


class CrossBatchMemory4MultiLabel(BaseLossWrapper, ModuleWithRecords):
    def __init__(self, loss, embedding_size, memory_size=1024, miner=None, **kwargs):
        super().__init__(loss=loss, **kwargs)
        self.loss = loss
        self.miner = miner
        self.embedding_size = embedding_size
        self.memory_size = memory_size
        self.num_classes = loss.num_classes
        self.reset_queue()
        self.add_to_recordable_attributes(
            list_of_names=["embedding_size", "memory_size", "queue_idx"], is_stat=False
        )

    @staticmethod
    def supported_losses():
        return ["MultiSupConLoss"]

    @classmethod
    def check_loss_support(cls, loss_name):
        if loss_name not in cls.supported_losses():
            raise Exception(f"CrossBatchMemory not supported for {loss_name}")

    def forward(self, embeddings, labels, indices_tuple=None, enqueue_mask=None):
        if indices_tuple is not None and enqueue_mask is not None:
            raise ValueError("indices_tuple and enqueue_mask are mutually exclusive")
        if enqueue_mask is not None:
            assert len(enqueue_mask) == len(embeddings)
        else:
            assert len(embeddings) <= len(self.embedding_memory)
        self.reset_stats()
        device = embeddings.device
        self.embedding_memory = c_f.to_device(
            self.embedding_memory, device=device, dtype=embeddings.dtype
        )

        if enqueue_mask is not None:
            emb_for_queue = embeddings[enqueue_mask]
            # labels are per-sample multilabel lists, so filter them with the mask explicitly
            labels_for_queue = [l for l, keep in zip(labels, enqueue_mask) if keep]
            embeddings = embeddings[~enqueue_mask]
            labels = [l for l, keep in zip(labels, enqueue_mask) if not keep]
            do_remove_self_comparisons = False
        else:
            emb_for_queue = embeddings
            labels_for_queue = labels
            do_remove_self_comparisons = True

        queue_batch_size = len(emb_for_queue)
        self.add_to_memory(emb_for_queue, labels_for_queue, queue_batch_size)

        if not self.has_been_filled:
            E_mem = self.embedding_memory[: self.queue_idx]
            L_mem = self.label_memory[: self.queue_idx]
        else:
            E_mem = self.embedding_memory
            L_mem = self.label_memory

        indices_tuple = self.create_indices_tuple(
            embeddings,
            labels,
            E_mem,
            L_mem,
            indices_tuple,
            do_remove_self_comparisons,
        )
        loss = self.loss(embeddings, labels, indices_tuple, E_mem, L_mem)
        return loss

    def add_to_memory(self, embeddings, labels, batch_size):
        self.curr_batch_idx = (
            torch.arange(self.queue_idx, self.queue_idx + batch_size) % self.memory_size
        )
        self.embedding_memory[self.curr_batch_idx] = embeddings.detach()
        # label_memory is a python list of multilabels, so assign slot by slot
        for i in range(len(self.curr_batch_idx)):
            self.label_memory[self.curr_batch_idx[i]] = labels[i]
        prev_queue_idx = self.queue_idx
        self.queue_idx = (self.queue_idx + batch_size) % self.memory_size
        if (not self.has_been_filled) and (self.queue_idx <= prev_queue_idx):
            self.has_been_filled = True

    def create_indices_tuple(
        self,
        embeddings,
        labels,
        E_mem,
        L_mem,
        input_indices_tuple,
        do_remove_self_comparisons,
    ):
        if self.miner:
            indices_tuple = self.miner(embeddings, labels, E_mem, L_mem)
        else:
            indices_tuple = mlmu.get_all_pairs_indices(labels, self.num_classes, L_mem)
        if do_remove_self_comparisons:
            indices_tuple = lmu.remove_self_comparisons(
                indices_tuple, self.curr_batch_idx, self.memory_size
            )

        if input_indices_tuple is not None:
            if len(input_indices_tuple) == 3 and len(indices_tuple) == 4:
                input_indices_tuple = mlmu.convert_to_pairs(
                    input_indices_tuple, labels, self.num_classes
                )
            elif len(input_indices_tuple) == 4 and len(indices_tuple) == 3:
                input_indices_tuple = mlmu.convert_to_triplets(
                    input_indices_tuple, labels
                )
            indices_tuple = c_f.concatenate_indices_tuples(
                indices_tuple, input_indices_tuple
            )

        return indices_tuple

    def reset_queue(self):
        self.register_buffer(
            "embedding_memory", torch.zeros(self.memory_size, self.embedding_size)
        )
        self.label_memory = [[] for _ in range(self.memory_size)]
        self.has_been_filled = False
        self.queue_idx = 0
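The memory in CrossBatchMemory4MultiLabel behaves as a ring buffer: add_to_memory writes each batch at queue_idx modulo memory_size and marks the queue as filled once the write pointer wraps back to or before its previous position. A small self-contained illustration of that index arithmetic (toy numbers, not from the commit):

import torch

memory_size, batch_size = 8, 3
queue_idx, has_been_filled = 0, False

for step in range(3):
    curr_batch_idx = torch.arange(queue_idx, queue_idx + batch_size) % memory_size
    prev_queue_idx = queue_idx
    queue_idx = (queue_idx + batch_size) % memory_size
    if not has_been_filled and queue_idx <= prev_queue_idx:
        has_been_filled = True
    print(curr_batch_idx.tolist(), queue_idx, has_been_filled)
# [0, 1, 2] 3 False
# [3, 4, 5] 6 False
# [6, 7, 0] 1 True   <- the pointer wrapped, so the queue counts as filled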
Lines changed: 101 additions & 0 deletions

@@ -0,0 +1,101 @@
import torch

from . import loss_and_miner_utils as lmu


def check_shapes_multilabels(embeddings, labels):
    if labels is not None and embeddings.shape[0] != len(labels):
        raise ValueError("Number of embeddings must equal number of labels")
    if labels is not None:
        if not isinstance(labels[0], (list, torch.Tensor)):
            raise ValueError("labels must be a list of 1d tensors or a list of lists")


def set_ref_emb(embeddings, labels, ref_emb, ref_labels):
    if ref_emb is None:
        ref_emb, ref_labels = embeddings, labels
    check_shapes_multilabels(ref_emb, ref_labels)
    return ref_emb, ref_labels


def convert_to_pairs(indices_tuple, labels, num_classes, ref_labels=None, device=None):
    """
    This returns anchor-positive and anchor-negative indices,
    regardless of what the input indices_tuple is
    Args:
        indices_tuple: tuple of tensors. Each tensor is 1d and specifies indices
                        within a batch
        labels: the multilabels, one list (or 1d tensor) of class indices per
                element in the batch
    """
    if indices_tuple is None:
        return get_all_pairs_indices(labels, num_classes, ref_labels, device=device)
    elif len(indices_tuple) == 4:
        return indices_tuple
    else:
        a, p, n = indices_tuple
        return a, p, a, n


def get_matches_and_diffs(labels, num_classes, ref_labels=None, device=None):
    if ref_labels is None:
        ref_labels = labels
    matches = jaccard(num_classes, labels, ref_labels, device=device)
    diffs = matches ^ 1
    if ref_labels is labels:
        matches.fill_diagonal_(0)
    return matches, diffs


def get_all_pairs_indices(labels, num_classes, ref_labels=None, device=None):
    """
    Given the multilabels, this will return 4 tensors.
    The first 2 tensors are the indices which form all positive pairs
    The second 2 tensors are the indices which form all negative pairs
    """
    matches, diffs = get_matches_and_diffs(labels, num_classes, ref_labels, device)
    a1_idx, p_idx = torch.where(matches)
    a2_idx, n_idx = torch.where(diffs)
    return a1_idx, p_idx, a2_idx, n_idx


def jaccard(n_classes, labels, ref_labels=None, threshold=0.3, device=torch.device("cpu")):
    if ref_labels is None:
        ref_labels = labels
    # convert multilabels to multi-hot vectors
    labels1 = [
        torch.nn.functional.one_hot(torch.as_tensor(label).long(), n_classes).sum(0)
        for label in labels
    ]
    labels2 = [
        torch.nn.functional.one_hot(torch.as_tensor(label).long(), n_classes).sum(0)
        for label in ref_labels
    ]
    # stack and convert to float for calculation convenience
    labels1 = torch.stack(labels1).float()
    labels2 = torch.stack(labels2).float()

    # compute jaccard similarity = intersection / union
    labels1_sum = labels1.sum(-1)
    labels2_sum = labels2.sum(-1)
    sum_of_sizes = labels1_sum.unsqueeze(1) + labels2_sum.unsqueeze(0)
    intersection = torch.mm(labels1, labels2.T)
    jaccard_sim = intersection / (sum_of_sizes - intersection)

    # return a 0/1 matrix marking pairs whose jaccard similarity exceeds the threshold
    label_matrix = torch.where(jaccard_sim > threshold, 1, 0).to(device)
    return label_matrix


def convert_to_triplets(indices_tuple, labels, num_classes=None, ref_labels=None, t_per_anchor=100):
    """
    This returns anchor-positive-negative triplets
    regardless of what the input indices_tuple is
    """
    if indices_tuple is None:
        if t_per_anchor == "all":
            return get_all_triplets_indices(labels, num_classes, ref_labels)
        else:
            return lmu.get_random_triplet_indices(
                labels, ref_labels, t_per_anchor=t_per_anchor
            )
    elif len(indices_tuple) == 3:
        return indices_tuple
    else:
        a1, p, a2, n = indices_tuple
        p_idx, n_idx = torch.where(a1.unsqueeze(1) == a2)
        return a1[p_idx], p[p_idx], n[n_idx]


def get_all_triplets_indices(labels, num_classes, ref_labels=None):
    matches, diffs = get_matches_and_diffs(labels, num_classes, ref_labels)
    triplets = matches.unsqueeze(2) * diffs.unsqueeze(1)
    return torch.where(triplets)
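As a worked example of the jaccard matching above (illustration only, plain torch, toy label sets): with 5 classes, label sets {0, 1} and {1, 2, 3} share one class, so intersection = 1, union = 2 + 3 - 1 = 4, and the Jaccard similarity is 0.25, below the default threshold of 0.3, so the pair is treated as negative:

import torch

n_classes = 5
a = torch.nn.functional.one_hot(torch.tensor([0, 1]), n_classes).sum(0).float()     # [1, 1, 0, 0, 0]
b = torch.nn.functional.one_hot(torch.tensor([1, 2, 3]), n_classes).sum(0).float()  # [0, 1, 1, 1, 0]

intersection = (a * b).sum()               # 1.0
union = a.sum() + b.sum() - intersection   # 4.0
print((intersection / union).item())       # 0.25 -> below 0.3, so a negative pair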
Lines changed: 40 additions & 0 deletions

@@ -0,0 +1,40 @@
import random
import unittest

import numpy as np
import torch

from pytorch_metric_learning.losses import (
    MultiSupConLoss,
    CrossBatchMemory4MultiLabel,
)


class TestMultiSupConLoss(unittest.TestCase):
    def test_multi_supcon_loss(self):
        n_cls = 10
        n_samples = 16
        n_dim = 256
        loss_func = MultiSupConLoss(num_classes=10)
        xbm_loss_func = CrossBatchMemory4MultiLabel(loss_func, n_dim, memory_size=128)

        # # test float32 and float64
        # for dtype in [torch.float32, torch.float64]:
        #     embeddings = torch.randn(n_samples, n_dim, dtype=dtype)
        #     labels = [random.sample(range(n_cls), np.random.randint(1, 4)) for i in range(n_samples)]
        #     loss = loss_func(embeddings, labels)
        #     self.assertTrue(loss >= 0)

        # # test cuda and cpu
        # for device in [torch.device("cpu"), torch.device("cuda")]:
        #     embeddings = torch.randn(n_samples, n_dim, dtype=dtype, device=device)
        #     labels = [random.sample(range(n_cls), np.random.randint(1, 4)) for i in range(n_samples)]
        #     loss = loss_func(embeddings, labels)
        #     self.assertTrue(loss >= 0)

        # test xbm
        n_batches = 10
        for b in range(n_batches):
            embeddings = torch.randn(n_samples, n_dim, dtype=torch.float32)
            labels = [random.sample(range(n_cls), np.random.randint(1, 4)) for i in range(n_samples)]
            loss = xbm_loss_func(embeddings, labels)
            self.assertTrue(loss >= 0)
