From e26b0e2f92ce8c2f1aa921833433f4758d690e67 Mon Sep 17 00:00:00 2001 From: Jordan Stomps Date: Wed, 12 Jul 2023 12:13:57 -0400 Subject: [PATCH 01/10] adding details to NTXentLoss documentation --- docs/losses.md | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/docs/losses.md b/docs/losses.md index 2af1f905..641da746 100644 --- a/docs/losses.md +++ b/docs/losses.md @@ -787,6 +787,13 @@ This is also known as InfoNCE, and is a generalization of the [NPairsLoss](losse - [Representation Learning with Contrastive Predictive Coding](https://arxiv.org/pdf/1807.03748.pdf){target=_blank} - [Momentum Contrast for Unsupervised Visual Representation Learning](https://arxiv.org/pdf/1911.05722.pdf){target=_blank} - [A Simple Framework for Contrastive Learning of Visual Representations](https://arxiv.org/pdf/2002.05709.pdf){target=_blank} + +In the equation below, loss is computed for each positive pair, `k_+`, in a batch, normalized by all pairs in the batch, `k_i in K`. +For each `embeddings` with `labels` and `ref_emb` with `ref_labels`, positive pair `(embeddings[i], ref_emb[j])` are defined when `labels[i] == ref_labels[j]`. +When `embeddings` and `ref_emb` are augmented versions of each other (e.g. SimCLR), `labels[i] == ref_labels[i]` (see [SelfSupervisedLoss](losses.md#selfsupervisedloss)). +Note that multiple positive pairs can exist if the same label is present multiple times in `labels` and/or `ref_labels`. + +Instead of passing labels (`NTXentLoss(embeddings, labels, ref_emb=ref_emb, ref_labels=ref_labels)`), `indices_tuple` could be passed (see [`pytorch_metric_learning.utils.loss_and_miner_utils.get_all_pairs_indices](https://github.com/KevinMusgrave/pytorch-metric-learning/blob/master/src/pytorch_metric_learning/utils/loss_and_miner_utils.py)). ```python losses.NTXentLoss(temperature=0.07, **kwargs) ``` @@ -799,6 +806,16 @@ losses.NTXentLoss(temperature=0.07, **kwargs) * **temperature**: This is tau in the above equation. The MoCo paper uses 0.07, while SimCLR uses 0.5. +**Other info:** + +For example, consider `labels = ref_labels = [0, 0, 1, 2]`. Two losses will be computed: + +* Positive pair of indices `[0, 1]`, with negative pairs of indices `[0, 2], [0, 3]`. + +* Positive pair of indices `[1, 0]`, with negative pairs of indices `[1, 2], [1, 3]`. + +Labels `1`, and `2` do not have positive pairs, and therefore the negative pair of indices `[2, 3]` will not be used. + **Default distance**: - [```CosineSimilarity()```](distances.md#cosinesimilarity) From fac0fe4afafc0128ad7dc587f8497a45284a8e8a Mon Sep 17 00:00:00 2001 From: Kevin Musgrave Date: Thu, 20 Jul 2023 21:40:07 -0400 Subject: [PATCH 02/10] Minor rewording and reorganization of NTXentLoss docs --- docs/losses.md | 34 +++++++++++++++++++--------------- 1 file changed, 19 insertions(+), 15 deletions(-) diff --git a/docs/losses.md b/docs/losses.md index 641da746..af3cd76e 100644 --- a/docs/losses.md +++ b/docs/losses.md @@ -788,12 +788,26 @@ This is also known as InfoNCE, and is a generalization of the [NPairsLoss](losse - [Momentum Contrast for Unsupervised Visual Representation Learning](https://arxiv.org/pdf/1911.05722.pdf){target=_blank} - [A Simple Framework for Contrastive Learning of Visual Representations](https://arxiv.org/pdf/2002.05709.pdf){target=_blank} -In the equation below, loss is computed for each positive pair, `k_+`, in a batch, normalized by all pairs in the batch, `k_i in K`. -For each `embeddings` with `labels` and `ref_emb` with `ref_labels`, positive pair `(embeddings[i], ref_emb[j])` are defined when `labels[i] == ref_labels[j]`. -When `embeddings` and `ref_emb` are augmented versions of each other (e.g. SimCLR), `labels[i] == ref_labels[i]` (see [SelfSupervisedLoss](losses.md#selfsupervisedloss)). -Note that multiple positive pairs can exist if the same label is present multiple times in `labels` and/or `ref_labels`. +??? "How exactly is the NTXentLoss computed?" + + In the equation below, a loss is computed for each positive pair (`k_+`) in a batch, normalized by all positive and negative pairs in the batch that have the same "anchor" embedding (`k_i in K`). + + - What does "anchor" mean? Let's say we have 3 pairs specified by batch indices: (0, 1), (0, 2), (1, 0). The first two pairs start with 0, so they have the same anchor. The third pair has the same indices as the first pair, but the order is different, so it does not have the same anchor. + + Given `embeddings` with corresponding `labels`, positive pairs `(embeddings[i], embeddings[j])` are defined when `labels[i] == labels[j]`. Now let's look at an example loss calculation: + + Consider `labels = [0, 0, 1, 2]`. Two losses will be computed: + + * A positive pair of indices `[0, 1]`, with negative pairs of indices `[0, 2], [0, 3]`. + + * A positive pair of indices `[1, 0]`, with negative pairs of indices `[1, 2], [1, 3]`. + + Labels `1`, and `2` do not have positive pairs, and therefore the negative pair of indices `[2, 3]` will not be used. + + Note that an anchor can belong to multiple positive pairs if its label is present multiple times in `labels`. + + Are you trying to use `NTXentLoss` for self-supervised learning? Specifically, do you have two sets of embeddings which are derived from data that are augmented versions of each other? If so, you can skip the step of creating the `labels` array, by wrapping `NTXentLoss` with [`SelfSupervisedLoss`](losses.md#selfsupervisedloss). -Instead of passing labels (`NTXentLoss(embeddings, labels, ref_emb=ref_emb, ref_labels=ref_labels)`), `indices_tuple` could be passed (see [`pytorch_metric_learning.utils.loss_and_miner_utils.get_all_pairs_indices](https://github.com/KevinMusgrave/pytorch-metric-learning/blob/master/src/pytorch_metric_learning/utils/loss_and_miner_utils.py)). ```python losses.NTXentLoss(temperature=0.07, **kwargs) ``` @@ -806,16 +820,6 @@ losses.NTXentLoss(temperature=0.07, **kwargs) * **temperature**: This is tau in the above equation. The MoCo paper uses 0.07, while SimCLR uses 0.5. -**Other info:** - -For example, consider `labels = ref_labels = [0, 0, 1, 2]`. Two losses will be computed: - -* Positive pair of indices `[0, 1]`, with negative pairs of indices `[0, 2], [0, 3]`. - -* Positive pair of indices `[1, 0]`, with negative pairs of indices `[1, 2], [1, 3]`. - -Labels `1`, and `2` do not have positive pairs, and therefore the negative pair of indices `[2, 3]` will not be used. - **Default distance**: - [```CosineSimilarity()```](distances.md#cosinesimilarity) From 36cfd0d9a1a5dea2b22d3b913e96d4281c00a6bb Mon Sep 17 00:00:00 2001 From: Kevin Musgrave Date: Thu, 20 Jul 2023 22:10:38 -0400 Subject: [PATCH 03/10] Update README.md --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index f8eb0c7b..728a42d2 100644 --- a/README.md +++ b/README.md @@ -246,6 +246,7 @@ Thanks to the contributors who made pull requests! | [layumi](https://github.com/layumi) | [InstanceLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#instanceloss) | | [NoTody](https://github.com/NoTody) | Helped add `ref_emb` and `ref_labels` to the distributed wrappers. | | [ElisonSherton](https://github.com/ElisonSherton) | Fixed an edge case in ArcFaceLoss. | +| [stompsjo](https://github.com/stompsjo) | Improved documentation for NTXentLoss | | [z1w](https://github.com/z1w) | | | [thinline72](https://github.com/thinline72) | | | [tpanum](https://github.com/tpanum) | | From 8e843863d00014c1b5294a1dd1c245118b74e1dc Mon Sep 17 00:00:00 2001 From: Kevin Musgrave Date: Fri, 21 Jul 2023 15:43:33 -0400 Subject: [PATCH 04/10] minor doc correction --- docs/losses.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/losses.md b/docs/losses.md index af3cd76e..05312656 100644 --- a/docs/losses.md +++ b/docs/losses.md @@ -790,7 +790,7 @@ This is also known as InfoNCE, and is a generalization of the [NPairsLoss](losse ??? "How exactly is the NTXentLoss computed?" - In the equation below, a loss is computed for each positive pair (`k_+`) in a batch, normalized by all positive and negative pairs in the batch that have the same "anchor" embedding (`k_i in K`). + In the equation below, a loss is computed for each positive pair (`k_+`) in a batch, normalized by itself and all negative pairs in the batch that have the same "anchor" embedding (`k_i in K`). - What does "anchor" mean? Let's say we have 3 pairs specified by batch indices: (0, 1), (0, 2), (1, 0). The first two pairs start with 0, so they have the same anchor. The third pair has the same indices as the first pair, but the order is different, so it does not have the same anchor. From cf82af53dc91ffb54e3bd6c055f78e8920506685 Mon Sep 17 00:00:00 2001 From: Kevin Musgrave Date: Tue, 25 Jul 2023 10:59:12 -0400 Subject: [PATCH 05/10] Update README.md --- README.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 728a42d2..846dae25 100644 --- a/README.md +++ b/README.md @@ -18,16 +18,16 @@ ## News +**July 25**: v2.3.0 +- Added [HistogramLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#histogramloss) +- Thank you [domenicoMuscill0](https://github.com/domenicoMuscill0). + **June 18**: v2.2.0 - Added [ManifoldLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#manifoldloss) and [P2SGradLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#p2sgradloss). - Added a `symmetric` flag to [SelfSupervisedLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#selfsupervisedloss). - See the [release notes](https://github.com/KevinMusgrave/pytorch-metric-learning/releases/tag/v2.2.0). - Thank you [domenicoMuscill0](https://github.com/domenicoMuscill0). -**April 5**: v2.1.0 -- Added [PNPLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#pnploss) -- Thanks you [interestingzhuo](https://github.com/interestingzhuo). - ## Documentation - [**View the documentation here**](https://kevinmusgrave.github.io/pytorch-metric-learning/) @@ -227,7 +227,7 @@ Thanks to the contributors who made pull requests! | Contributor | Highlights | | -- | -- | -|[domenicoMuscill0](https://github.com/domenicoMuscill0)| - [ManifoldLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#manifoldloss)
- [P2SGradLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#p2sgradloss) +|[domenicoMuscill0](https://github.com/domenicoMuscill0)| - [ManifoldLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#manifoldloss)
- [P2SGradLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#p2sgradloss)
-[HistogramLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#histogramloss) |[mlopezantequera](https://github.com/mlopezantequera) | - Made the [testers](https://kevinmusgrave.github.io/pytorch-metric-learning/testers) work on any combination of query and reference sets
- Made [AccuracyCalculator](https://kevinmusgrave.github.io/pytorch-metric-learning/accuracy_calculation/) work with arbitrary label comparisons | |[cwkeam](https://github.com/cwkeam) | - [SelfSupervisedLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#selfsupervisedloss)
- [VICRegLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#vicregloss)
- Added mean reciprocal rank accuracy to [AccuracyCalculator](https://kevinmusgrave.github.io/pytorch-metric-learning/accuracy_calculation/)
- BaseLossWrapper| |[marijnl](https://github.com/marijnl)| - [BatchEasyHardMiner](https://kevinmusgrave.github.io/pytorch-metric-learning/miners/#batcheasyhardminer)
- [TwoStreamMetricLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/trainers/#twostreammetricloss)
- [GlobalTwoStreamEmbeddingSpaceTester](https://kevinmusgrave.github.io/pytorch-metric-learning/testers/#globaltwostreamembeddingspacetester)
- [Example using trainers.TwoStreamMetricLoss](https://github.com/KevinMusgrave/pytorch-metric-learning/blob/master/examples/notebooks/TwoStreamMetricLoss.ipynb) | From ac607007dc62666f9de850cd5b8e5694ff0da1c2 Mon Sep 17 00:00:00 2001 From: Kevin Musgrave Date: Tue, 25 Jul 2023 11:00:08 -0400 Subject: [PATCH 06/10] Update README.md --- README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 846dae25..9c29da8e 100644 --- a/README.md +++ b/README.md @@ -227,7 +227,7 @@ Thanks to the contributors who made pull requests! | Contributor | Highlights | | -- | -- | -|[domenicoMuscill0](https://github.com/domenicoMuscill0)| - [ManifoldLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#manifoldloss)
- [P2SGradLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#p2sgradloss)
-[HistogramLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#histogramloss) +|[domenicoMuscill0](https://github.com/domenicoMuscill0)| - [ManifoldLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#manifoldloss)
- [P2SGradLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#p2sgradloss)
- [HistogramLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#histogramloss) |[mlopezantequera](https://github.com/mlopezantequera) | - Made the [testers](https://kevinmusgrave.github.io/pytorch-metric-learning/testers) work on any combination of query and reference sets
- Made [AccuracyCalculator](https://kevinmusgrave.github.io/pytorch-metric-learning/accuracy_calculation/) work with arbitrary label comparisons | |[cwkeam](https://github.com/cwkeam) | - [SelfSupervisedLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#selfsupervisedloss)
- [VICRegLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#vicregloss)
- Added mean reciprocal rank accuracy to [AccuracyCalculator](https://kevinmusgrave.github.io/pytorch-metric-learning/accuracy_calculation/)
- BaseLossWrapper| |[marijnl](https://github.com/marijnl)| - [BatchEasyHardMiner](https://kevinmusgrave.github.io/pytorch-metric-learning/miners/#batcheasyhardminer)
- [TwoStreamMetricLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/trainers/#twostreammetricloss)
- [GlobalTwoStreamEmbeddingSpaceTester](https://kevinmusgrave.github.io/pytorch-metric-learning/testers/#globaltwostreamembeddingspacetester)
- [Example using trainers.TwoStreamMetricLoss](https://github.com/KevinMusgrave/pytorch-metric-learning/blob/master/examples/notebooks/TwoStreamMetricLoss.ipynb) | @@ -260,6 +260,7 @@ Thanks to the contributors who made pull requests! | [michaeldeyzel](https://github.com/michaeldeyzel) | | | [HSinger04](https://github.com/HSinger04) | | | [rheum](https://github.com/rheum) | | +| [bot66](https://github.com/bot66) | | From 4c6f2cad642398ed0d26a9759ee6c67f78516ad1 Mon Sep 17 00:00:00 2001 From: Qibin Liang <46101145+PhyseChan@users.noreply.github.com> Date: Sat, 7 Oct 2023 20:12:24 +0800 Subject: [PATCH 07/10] [feat] multi-supcon --- .../losses/__init__.py | 2 + .../losses/multilabel_supcon_loss.py | 86 ++++++++++++ .../losses/xbm_multilabel.py | 132 ++++++++++++++++++ .../utils/multilabel_loss_and_miner_utils.py | 101 ++++++++++++++ tests/losses/test_multilabel_supcon_loss.py | 40 ++++++ 5 files changed, 361 insertions(+) create mode 100644 src/pytorch_metric_learning/losses/multilabel_supcon_loss.py create mode 100644 src/pytorch_metric_learning/losses/xbm_multilabel.py create mode 100644 src/pytorch_metric_learning/utils/multilabel_loss_and_miner_utils.py create mode 100644 tests/losses/test_multilabel_supcon_loss.py diff --git a/src/pytorch_metric_learning/losses/__init__.py b/src/pytorch_metric_learning/losses/__init__.py index a0ba7407..d3b98c94 100644 --- a/src/pytorch_metric_learning/losses/__init__.py +++ b/src/pytorch_metric_learning/losses/__init__.py @@ -35,3 +35,5 @@ from .triplet_margin_loss import TripletMarginLoss from .tuplet_margin_loss import TupletMarginLoss from .vicreg_loss import VICRegLoss +from .multilabel_supcon_loss import MultiSupConLoss +from .xbm_multilabel import CrossBatchMemory4MultiLabel diff --git a/src/pytorch_metric_learning/losses/multilabel_supcon_loss.py b/src/pytorch_metric_learning/losses/multilabel_supcon_loss.py new file mode 100644 index 00000000..512439f9 --- /dev/null +++ b/src/pytorch_metric_learning/losses/multilabel_supcon_loss.py @@ -0,0 +1,86 @@ +import torch + +from ..distances import CosineSimilarity +from ..reducers import AvgNonZeroReducer +from ..utils import common_functions as c_f +from ..utils import multilabel_loss_and_miner_utils as mlmu +from ..utils import loss_and_miner_utils as lmu +from .generic_pair_loss import GenericPairLoss + + +# adapted from https://github.com/HobbitLong/SupContrast +class MultiSupConLoss(GenericPairLoss): + def __init__(self, num_classes, temperature=0.1, **kwargs): + super().__init__(mat_based_loss=True, **kwargs) + self.temperature = temperature + self.add_to_recordable_attributes(list_of_names=["temperature"], is_stat=False) + self.num_classes = num_classes + + def _compute_loss(self, mat, pos_mask, neg_mask): + if pos_mask.bool().any() and neg_mask.bool().any(): + # if dealing with actual distances, use negative distances + if not self.distance.is_inverted: + mat = -mat + mat = mat / self.temperature + mat_max, _ = mat.max(dim=1, keepdim=True) + mat = mat - mat_max.detach() # for numerical stability + + denominator = lmu.logsumexp( + mat, keep_mask=(pos_mask + neg_mask).bool(), add_one=False, dim=1 + ) + log_prob = mat - denominator + mean_log_prob_pos = (pos_mask * log_prob).sum(dim=1) / ( + pos_mask.sum(dim=1) + c_f.small_val(mat.dtype) + ) + + return { + "loss": { + "losses": -mean_log_prob_pos, + "indices": c_f.torch_arange_from_size(mat), + "reduction_type": "element", + } + } + return self.zero_losses() + + def get_default_reducer(self): + return AvgNonZeroReducer() + + def get_default_distance(self): + return CosineSimilarity() + + def mat_based_loss(self, mat, indices_tuple): + a1, p, a2, n = indices_tuple + pos_mask, neg_mask = torch.zeros_like(mat), torch.zeros_like(mat) + pos_mask[a1, p] = 1 + neg_mask[a2, n] = 1 + return self._compute_loss(mat, pos_mask, neg_mask) + + def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels): + c_f.labels_or_indices_tuple_required(labels, indices_tuple) + indices_tuple = mlmu.convert_to_pairs(indices_tuple, labels, self.num_classes, ref_labels, device=embeddings.device) + if all(len(x) <= 1 for x in indices_tuple): + return self.zero_losses() + mat = self.distance(embeddings, ref_emb) + return self.loss_method(mat, indices_tuple) + + def forward( + self, embeddings, labels=None, indices_tuple=None, ref_emb=None, ref_labels=None + ): + """ + Args: + embeddings: tensor of size (batch_size, embedding_size) + labels: tensor of size (batch_size) + indices_tuple: tuple of size 3 for triplets (anchors, positives, negatives) + or size 4 for pairs (anchor1, postives, anchor2, negatives) + Can also be left as None + Returns: the loss + """ + self.reset_stats() + mlmu.check_shapes_multilabels(embeddings, labels) + ref_emb, ref_labels = mlmu.set_ref_emb(embeddings, labels, ref_emb, ref_labels) + loss_dict = self.compute_loss( + embeddings, labels, indices_tuple, ref_emb, ref_labels + ) + self.add_embedding_regularization_to_loss_dict(loss_dict, embeddings) + return self.reducer(loss_dict, embeddings, labels) + diff --git a/src/pytorch_metric_learning/losses/xbm_multilabel.py b/src/pytorch_metric_learning/losses/xbm_multilabel.py new file mode 100644 index 00000000..cf9b9f40 --- /dev/null +++ b/src/pytorch_metric_learning/losses/xbm_multilabel.py @@ -0,0 +1,132 @@ +import torch + +from ..utils import common_functions as c_f +# replace the functions of loss_and_miner_utils by multisupcon's +from ..utils import multilabel_loss_and_miner_utils as mlmu +from ..utils import loss_and_miner_utils as lmu +from ..utils.module_with_records import ModuleWithRecords +from .base_loss_wrapper import BaseLossWrapper + + +class CrossBatchMemory4MultiLabel(BaseLossWrapper, ModuleWithRecords): + def __init__(self, loss, embedding_size, memory_size=1024, miner=None, **kwargs): + super().__init__(loss=loss, **kwargs) + self.loss = loss + self.miner = miner + self.embedding_size = embedding_size + self.memory_size = memory_size + self.num_classes = loss.num_classes + self.reset_queue() + self.add_to_recordable_attributes( + list_of_names=["embedding_size", "memory_size", "queue_idx"], is_stat=False + ) + + @staticmethod + def supported_losses(): + return [ + "MultiSupConLoss" + ] + + @classmethod + def check_loss_support(cls, loss_name): + if loss_name not in cls.supported_losses(): + raise Exception(f"CrossBatchMemory not supported for {loss_name}") + + def forward(self, embeddings, labels, indices_tuple=None, enqueue_mask=None): + if indices_tuple is not None and enqueue_mask is not None: + raise ValueError("indices_tuple and enqueue_mask are mutually exclusive") + if enqueue_mask is not None: + assert len(enqueue_mask) == len(embeddings) + else: + assert len(embeddings) <= len(self.embedding_memory) + self.reset_stats() + device = embeddings.device + self.embedding_memory = c_f.to_device( + self.embedding_memory, device=device, dtype=embeddings.dtype + ) + + if enqueue_mask is not None: + emb_for_queue = embeddings[enqueue_mask] + labels_for_queue = labels[enqueue_mask] + embeddings = embeddings[~enqueue_mask] + labels = labels[~enqueue_mask] + do_remove_self_comparisons = False + else: + emb_for_queue = embeddings + labels_for_queue = labels + do_remove_self_comparisons = True + + queue_batch_size = len(emb_for_queue) + self.add_to_memory(emb_for_queue, labels_for_queue, queue_batch_size) + + if not self.has_been_filled: + E_mem = self.embedding_memory[: self.queue_idx] + L_mem = self.label_memory[: self.queue_idx] + else: + E_mem = self.embedding_memory + L_mem = self.label_memory + indices_tuple = self.create_indices_tuple( + embeddings, + labels, + E_mem, + L_mem, + indices_tuple, + do_remove_self_comparisons, + ) + loss = self.loss(embeddings, labels, indices_tuple, E_mem, L_mem) + return loss + + def add_to_memory(self, embeddings, labels, batch_size): + self.curr_batch_idx = ( + torch.arange( + self.queue_idx, self.queue_idx + batch_size + ) + % self.memory_size + ) + self.embedding_memory[self.curr_batch_idx] = embeddings.detach() + # self.label_memory[self.curr_batch_idx] = labels + for i in range(len(self.curr_batch_idx)): + self.label_memory[self.curr_batch_idx[i]] = labels[i] + prev_queue_idx = self.queue_idx + self.queue_idx = (self.queue_idx + batch_size) % self.memory_size + if (not self.has_been_filled) and (self.queue_idx <= prev_queue_idx): + self.has_been_filled = True + + def create_indices_tuple( + self, + embeddings, + labels, + E_mem, + L_mem, + input_indices_tuple, + do_remove_self_comparisons, + ): + if self.miner: + indices_tuple = self.miner(embeddings, labels, E_mem, L_mem) + else: + indices_tuple = mlmu.get_all_pairs_indices(labels, self.num_classes, L_mem) + if do_remove_self_comparisons: + indices_tuple = lmu.remove_self_comparisons( + indices_tuple, self.curr_batch_idx, self.memory_size + ) + + if input_indices_tuple is not None: + if len(input_indices_tuple) == 3 and len(indices_tuple) == 4: + input_indices_tuple = mlmu.convert_to_pairs(input_indices_tuple, labels, self.num_classes) + elif len(input_indices_tuple) == 4 and len(indices_tuple) == 3: + input_indices_tuple = mlmu.convert_to_triplets( + input_indices_tuple, labels + ) + indices_tuple = c_f.concatenate_indices_tuples( + indices_tuple, input_indices_tuple + ) + + return indices_tuple + + def reset_queue(self): + self.register_buffer( + "embedding_memory", torch.zeros(self.memory_size, self.embedding_size) + ) + self.label_memory = [[] for i in range(self.memory_size)] + self.has_been_filled = False + self.queue_idx = 0 diff --git a/src/pytorch_metric_learning/utils/multilabel_loss_and_miner_utils.py b/src/pytorch_metric_learning/utils/multilabel_loss_and_miner_utils.py new file mode 100644 index 00000000..de08e86d --- /dev/null +++ b/src/pytorch_metric_learning/utils/multilabel_loss_and_miner_utils.py @@ -0,0 +1,101 @@ +import torch +from . import loss_and_miner_utils as lmu + +def check_shapes_multilabels(embeddings, labels): + if labels is not None and embeddings.shape[0] != len(labels): + raise ValueError("Number of embeddings must equal number of labels") + if labels is not None: + if isinstance(labels[0], list) or isinstance(labels[0], torch.Tensor): + pass + else: + raise ValueError("labels must be a list of 1d tensors or a list of lists") + +def set_ref_emb(embeddings, labels, ref_emb, ref_labels): + if ref_emb is None: + ref_emb, ref_labels = embeddings, labels + check_shapes_multilabels(ref_emb, ref_labels) + return ref_emb, ref_labels + +def convert_to_pairs(indices_tuple, labels, num_classes, ref_labels=None, device=None): + """ + This returns anchor-positive and anchor-negative indices, + regardless of what the input indices_tuple is + Args: + indices_tuple: tuple of tensors. Each tensor is 1d and specifies indices + within a batch + labels: a tensor which has the label for each element in a batch + """ + if indices_tuple is None: + return get_all_pairs_indices(labels, num_classes, ref_labels, device=device) + elif len(indices_tuple) == 4: + return indices_tuple + else: + a, p, n = indices_tuple + return a, p, a, n + +def get_matches_and_diffs(labels, num_classes, ref_labels=None, device=None): + matches = jaccard(num_classes, labels, ref_labels, device=device) + diffs = matches ^ 1 + if ref_labels is labels: + matches.fill_diagonal_(0) + return matches, diffs + + +def get_all_pairs_indices(labels, num_classes, ref_labels=None, device=None): + """ + Given a tensor of labels, this will return 4 tensors. + The first 2 tensors are the indices which form all positive pairs + The second 2 tensors are the indices which form all negative pairs + """ + matches, diffs = get_matches_and_diffs(labels, num_classes, ref_labels, device) + a1_idx, p_idx = torch.where(matches) + a2_idx, n_idx = torch.where(diffs) + return a1_idx, p_idx, a2_idx, n_idx + +def jaccard(n_classes, labels, ref_labels=None, threshold=0.3, device=torch.device("cpu")): + if ref_labels is None: + ref_labels = labels + # convert multilabels to scatter labels + labels1 = [torch.nn.functional.one_hot(torch.Tensor(label).long(), n_classes).sum(0) for label in labels] + labels2 = [torch.nn.functional.one_hot(torch.Tensor(label).long(), n_classes).sum(0) for label in ref_labels] + # stack and convert to float for calculation convenience + labels1 = torch.stack(labels1).float() + labels2 = torch.stack(labels2).float() + + # compute jaccard similarity + # jaccard = intersection / union + labels1_union = labels1.sum(-1) + labels2_union = labels2.sum(-1) + union = labels1_union.unsqueeze(1) + labels2_union.unsqueeze(0) + intersection = torch.mm(labels1, labels2.T) + jaccard = intersection / (union - intersection) + + # return indices of jaccard similarity above threshold + label_matrix = torch.where(jaccard > threshold, 1, 0).to(device) + return label_matrix + +def convert_to_triplets(indices_tuple, labels, ref_labels=None, t_per_anchor=100): + """ + This returns anchor-positive-negative triplets + regardless of what the input indices_tuple is + """ + if indices_tuple is None: + if t_per_anchor == "all": + return get_all_triplets_indices(labels, ref_labels) + else: + return lmu.get_random_triplet_indices( + labels, ref_labels, t_per_anchor=t_per_anchor + ) + elif len(indices_tuple) == 3: + return indices_tuple + else: + a1, p, a2, n = indices_tuple + p_idx, n_idx = torch.where(a1.unsqueeze(1) == a2) + return a1[p_idx], p[p_idx], n[n_idx] + + +def get_all_triplets_indices(labels, ref_labels=None): + matches, diffs = get_matches_and_diffs(labels, ref_labels) + triplets = matches.unsqueeze(2) * diffs.unsqueeze(1) + return torch.where(triplets) + diff --git a/tests/losses/test_multilabel_supcon_loss.py b/tests/losses/test_multilabel_supcon_loss.py new file mode 100644 index 00000000..cca3c61a --- /dev/null +++ b/tests/losses/test_multilabel_supcon_loss.py @@ -0,0 +1,40 @@ +import unittest + +import torch +import numpy as np +import random + +from pytorch_metric_learning.losses import ( + MultiSupConLoss, + CrossBatchMemory4MultiLabel +) + +class TestMultiSupConLoss(unittest.TestCase): + def test_multi_supcon_loss(self): + n_cls = 10 + n_samples = 16 + n_dim = 256 + loss_func = MultiSupConLoss(num_classes=10) + xbm_loss_func = CrossBatchMemory4MultiLabel(loss_func, n_dim, memory_size=128) + + # # test float32 and float64 + # for dtype in [torch.float32, torch.float64]: + # embeddings = torch.randn(n_samples, n_dim, dtype=dtype) + # labels = [random.sample(range(n_cls), np.random.randint(1, 4)) for i in range(n_samples)] + # loss = loss_func(embeddings, labels) + # self.assertTrue(loss >= 0) + + # # test cuda and cpu + # for device in [torch.device("cpu"),torch.device("cuda")]: + # embeddings = torch.randn(n_samples, n_dim, dtype=dtype, device=device) + # labels = [random.sample(range(n_cls), np.random.randint(1, 4)) for i in range(n_samples)] + # loss = loss_func(embeddings, labels) + # self.assertTrue(loss >= 0) + + # test xbm + batchs = 10 + for b in range(batchs): + embeddings = torch.randn(n_samples, n_dim, dtype=torch.float32) + labels = [random.sample(range(n_cls), np.random.randint(1, 4)) for i in range(n_samples)] + loss = xbm_loss_func(embeddings, labels) + self.assertTrue(loss >= 0) \ No newline at end of file From 949a45f898610c4d813103caaeda8ac31fb7fb44 Mon Sep 17 00:00:00 2001 From: Qibin Liang <46101145+PhyseChan@users.noreply.github.com> Date: Sun, 8 Oct 2023 14:10:48 +0800 Subject: [PATCH 08/10] Add multi-supcon Add cross-batch memory for multi-supcon Add test cases --- .../losses/__init__.py | 3 +- .../losses/multilabel_supcon_loss.py | 293 +++++++++++++++++- .../losses/xbm_multilabel.py | 132 -------- .../utils/multilabel_loss_and_miner_utils.py | 101 ------ tests/losses/test_multilabel_supcon_loss.py | 158 ++++++++-- 5 files changed, 413 insertions(+), 274 deletions(-) delete mode 100644 src/pytorch_metric_learning/losses/xbm_multilabel.py delete mode 100644 src/pytorch_metric_learning/utils/multilabel_loss_and_miner_utils.py diff --git a/src/pytorch_metric_learning/losses/__init__.py b/src/pytorch_metric_learning/losses/__init__.py index d3b98c94..cfff0813 100644 --- a/src/pytorch_metric_learning/losses/__init__.py +++ b/src/pytorch_metric_learning/losses/__init__.py @@ -35,5 +35,4 @@ from .triplet_margin_loss import TripletMarginLoss from .tuplet_margin_loss import TupletMarginLoss from .vicreg_loss import VICRegLoss -from .multilabel_supcon_loss import MultiSupConLoss -from .xbm_multilabel import CrossBatchMemory4MultiLabel +from .multilabel_supcon_loss import MultiSupConLoss, CrossBatchMemory4MultiLabel \ No newline at end of file diff --git a/src/pytorch_metric_learning/losses/multilabel_supcon_loss.py b/src/pytorch_metric_learning/losses/multilabel_supcon_loss.py index 512439f9..a8e226ab 100644 --- a/src/pytorch_metric_learning/losses/multilabel_supcon_loss.py +++ b/src/pytorch_metric_learning/losses/multilabel_supcon_loss.py @@ -3,20 +3,22 @@ from ..distances import CosineSimilarity from ..reducers import AvgNonZeroReducer from ..utils import common_functions as c_f -from ..utils import multilabel_loss_and_miner_utils as mlmu from ..utils import loss_and_miner_utils as lmu +from ..utils.module_with_records import ModuleWithRecords from .generic_pair_loss import GenericPairLoss - +from .base_loss_wrapper import BaseLossWrapper # adapted from https://github.com/HobbitLong/SupContrast +# modified for multi-supcon class MultiSupConLoss(GenericPairLoss): - def __init__(self, num_classes, temperature=0.1, **kwargs): + def __init__(self, num_classes, temperature=0.1, threshold=0.3, **kwargs): super().__init__(mat_based_loss=True, **kwargs) self.temperature = temperature self.add_to_recordable_attributes(list_of_names=["temperature"], is_stat=False) self.num_classes = num_classes + self.threshold = threshold - def _compute_loss(self, mat, pos_mask, neg_mask): + def _compute_loss(self, mat, pos_mask, neg_mask, multi_val): if pos_mask.bool().any() and neg_mask.bool().any(): # if dealing with actual distances, use negative distances if not self.distance.is_inverted: @@ -29,7 +31,7 @@ def _compute_loss(self, mat, pos_mask, neg_mask): mat, keep_mask=(pos_mask + neg_mask).bool(), add_one=False, dim=1 ) log_prob = mat - denominator - mean_log_prob_pos = (pos_mask * log_prob).sum(dim=1) / ( + mean_log_prob_pos = (multi_val * log_prob * pos_mask).sum(dim=1) / ( pos_mask.sum(dim=1) + c_f.small_val(mat.dtype) ) @@ -48,16 +50,22 @@ def get_default_reducer(self): def get_default_distance(self): return CosineSimilarity() + # ==== class methods below are overriden for adaptability to multi-supcon ==== + def mat_based_loss(self, mat, indices_tuple): - a1, p, a2, n = indices_tuple + a1, p, a2, n, jaccard_mat = indices_tuple pos_mask, neg_mask = torch.zeros_like(mat), torch.zeros_like(mat) pos_mask[a1, p] = 1 neg_mask[a2, n] = 1 - return self._compute_loss(mat, pos_mask, neg_mask) + return self._compute_loss(mat, pos_mask, neg_mask, jaccard_mat) def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels): c_f.labels_or_indices_tuple_required(labels, indices_tuple) - indices_tuple = mlmu.convert_to_pairs(indices_tuple, labels, self.num_classes, ref_labels, device=embeddings.device) + indices_tuple = convert_to_pairs( + indices_tuple, + labels, + ref_labels, + threshold=self.threshold) if all(len(x) <= 1 for x in indices_tuple): return self.zero_losses() mat = self.distance(embeddings, ref_emb) @@ -76,11 +84,276 @@ def forward( Returns: the loss """ self.reset_stats() - mlmu.check_shapes_multilabels(embeddings, labels) - ref_emb, ref_labels = mlmu.set_ref_emb(embeddings, labels, ref_emb, ref_labels) + check_shapes_multilabels(embeddings, labels) + ref_emb, ref_labels = set_ref_emb(embeddings, labels, ref_emb, ref_labels) loss_dict = self.compute_loss( embeddings, labels, indices_tuple, ref_emb, ref_labels ) self.add_embedding_regularization_to_loss_dict(loss_dict, embeddings) return self.reducer(loss_dict, embeddings, labels) + # ========================================================================= + + +# ================== cross batch memory for multi-supcon ================== +class CrossBatchMemory4MultiLabel(BaseLossWrapper, ModuleWithRecords): + def __init__(self, loss, embedding_size, memory_size=1024, miner=None, **kwargs): + super().__init__(loss=loss, **kwargs) + self.loss = loss + self.miner = miner + self.embedding_size = embedding_size + self.memory_size = memory_size + self.num_classes = loss.num_classes + self.reset_queue() + self.add_to_recordable_attributes( + list_of_names=["embedding_size", "memory_size", "queue_idx"], is_stat=False + ) + + @staticmethod + def supported_losses(): + return [ + "MultiSupConLoss" + ] + + @classmethod + def check_loss_support(cls, loss_name): + if loss_name not in cls.supported_losses(): + raise Exception(f"CrossBatchMemory not supported for {loss_name}") + + def forward(self, embeddings, labels, indices_tuple=None, enqueue_mask=None): + if indices_tuple is not None and enqueue_mask is not None: + raise ValueError("indices_tuple and enqueue_mask are mutually exclusive") + if enqueue_mask is not None: + assert len(enqueue_mask) == len(embeddings) + else: + assert len(embeddings) <= len(self.embedding_memory) + self.reset_stats() + device = embeddings.device + labels = c_f.to_device(labels, device=device) + self.embedding_memory = c_f.to_device( + self.embedding_memory, device=device, dtype=embeddings.dtype + ) + self.label_memory = c_f.to_device( + self.label_memory, device=device, dtype=labels.dtype + ) + + if enqueue_mask is not None: + emb_for_queue = embeddings[enqueue_mask] + labels_for_queue = labels[enqueue_mask] + embeddings = embeddings[~enqueue_mask] + labels = labels[~enqueue_mask] + do_remove_self_comparisons = False + else: + emb_for_queue = embeddings + labels_for_queue = labels + do_remove_self_comparisons = True + + queue_batch_size = len(emb_for_queue) + self.add_to_memory(emb_for_queue, labels_for_queue, queue_batch_size) + + if not self.has_been_filled: + E_mem = self.embedding_memory[: self.queue_idx] + L_mem = self.label_memory[: self.queue_idx] + else: + E_mem = self.embedding_memory + L_mem = self.label_memory + + indices_tuple = self.create_indices_tuple( + embeddings, + labels, + E_mem, + L_mem, + indices_tuple, + do_remove_self_comparisons, + ) + loss = self.loss(embeddings, labels, indices_tuple, E_mem, L_mem) + return loss + + def add_to_memory(self, embeddings, labels, batch_size): + self.curr_batch_idx = ( + torch.arange( + self.queue_idx, self.queue_idx + batch_size, device=labels.device + ) + % self.memory_size + ) + self.embedding_memory[self.curr_batch_idx] = embeddings.detach() + self.label_memory[self.curr_batch_idx] = labels.detach() + prev_queue_idx = self.queue_idx + self.queue_idx = (self.queue_idx + batch_size) % self.memory_size + if (not self.has_been_filled) and (self.queue_idx <= prev_queue_idx): + self.has_been_filled = True + + def create_indices_tuple( + self, + embeddings, + labels, + E_mem, + L_mem, + input_indices_tuple, + do_remove_self_comparisons, + ): + if self.miner: + indices_tuple = self.miner(embeddings, labels, E_mem, L_mem) + else: + indices_tuple = get_all_pairs_indices(labels, L_mem) + + if do_remove_self_comparisons: + indices_tuple = remove_self_comparisons( + indices_tuple, self.curr_batch_idx, self.memory_size + ) + + if input_indices_tuple is not None: + if len(input_indices_tuple) == 3 and len(indices_tuple) == 4: + input_indices_tuple = convert_to_pairs(input_indices_tuple, labels) + elif len(input_indices_tuple) == 4 and len(indices_tuple) == 3: + input_indices_tuple = convert_to_triplets( + input_indices_tuple, labels + ) + indices_tuple = c_f.concatenate_indices_tuples( + indices_tuple, input_indices_tuple + ) + + return indices_tuple + + def reset_queue(self): + self.register_buffer( + "embedding_memory", torch.zeros(self.memory_size, self.embedding_size) + ) + self.register_buffer( + "label_memory", torch.zeros(self.memory_size, self.num_classes) + ) + self.has_been_filled = False + self.queue_idx = 0 + +# ========================================================================= + +# compute jaccard similarity +def jaccard(labels, ref_labels=None): + if ref_labels is None: + ref_labels = labels + + labels1 = labels.float() + labels2 = ref_labels.float() + + # compute jaccard similarity + # jaccard = intersection / union + labels1_union = labels1.sum(-1) + labels2_union = labels2.sum(-1) + union = labels1_union.unsqueeze(1) + labels2_union.unsqueeze(0) + intersection = torch.mm(labels1, labels2.T) + jaccard_matrix = intersection / (union - intersection) + + # return indices of jaccard similarity above threshold + return jaccard_matrix + +# ====== methods below are overriden for adaptability to multi-supcon ====== + +# use jaccard similarity to get matches +def get_matches_and_diffs(labels, ref_labels=None, threshold=0.3): + if ref_labels is None: + ref_labels = labels + jaccard_matrix = jaccard(labels, ref_labels) + matches = torch.where(jaccard_matrix > threshold, 1, 0) + diffs = matches ^ 1 + if ref_labels is labels: + matches.fill_diagonal_(0) + return matches, diffs, jaccard_matrix + +def check_shapes_multilabels(embeddings, labels): + if labels is not None and embeddings.shape[0] != labels.shape[0]: + raise ValueError("Number of embeddings must equal number of labels") + if labels is not None and labels.ndim != 2: + raise ValueError("labels must be a 1D tensor of shape (batch_size,)") + + +def set_ref_emb(embeddings, labels, ref_emb, ref_labels): + if ref_emb is None: + ref_emb, ref_labels = embeddings, labels + check_shapes_multilabels(ref_emb, ref_labels) + return ref_emb, ref_labels + + +def convert_to_pairs(indices_tuple, labels, ref_labels=None, threshold=0.3): + """ + This returns anchor-positive and anchor-negative indices, + regardless of what the input indices_tuple is + Args: + indices_tuple: tuple of tensors. Each tensor is 1d and specifies indices + within a batch + labels: a tensor which has the label for each element in a batch + """ + if indices_tuple is None: + return get_all_pairs_indices(labels, ref_labels, threshold=threshold) + elif len(indices_tuple) == 5: + return indices_tuple + else: + a, p, n, jaccard_mat = indices_tuple + return a, p, a, n,jaccard_mat + + +def get_all_pairs_indices(labels, ref_labels=None, threshold=0.3): + """ + Given a tensor of labels, this will return 4 tensors. + The first 2 tensors are the indices which form all positive pairs + The second 2 tensors are the indices which form all negative pairs + """ + matches, diffs, multi_val = get_matches_and_diffs(labels, ref_labels, threshold=threshold) + a1_idx, p_idx = torch.where(matches) + a2_idx, n_idx = torch.where(diffs) + return a1_idx, p_idx, a2_idx, n_idx, multi_val + + +def convert_to_triplets(indices_tuple, labels, ref_labels=None, t_per_anchor=100): + """ + This returns anchor-positive-negative triplets + regardless of what the input indices_tuple is + """ + if indices_tuple is None: + if t_per_anchor == "all": + return get_all_triplets_indices(labels, ref_labels) + else: + return lmu.get_random_triplet_indices( + labels, ref_labels, t_per_anchor=t_per_anchor + ) + elif len(indices_tuple) == 3: + return indices_tuple + else: + a1, p, a2, n = indices_tuple + p_idx, n_idx = torch.where(a1.unsqueeze(1) == a2) + return a1[p_idx], p[p_idx], n[n_idx] + + +def get_all_triplets_indices(labels, ref_labels=None): + matches, diffs = get_matches_and_diffs(labels, ref_labels) + triplets = matches.unsqueeze(2) * diffs.unsqueeze(1) + return torch.where(triplets) + + +def remove_self_comparisons( + indices_tuple, curr_batch_idx, ref_size, ref_is_subset=False +): + # remove self-comparisons + assert len(indices_tuple) in [4, 5] + s, e = curr_batch_idx[0], curr_batch_idx[-1] + if len(indices_tuple) == 4: + a, p, n, jaccard_mat = indices_tuple + keep_mask = lmu.not_self_comparisons( + a, p, s, e, curr_batch_idx, ref_size, ref_is_subset + ) + a = a[keep_mask] + p = p[keep_mask] + n = n[keep_mask] + assert len(a) == len(p) == len(n) + return a, p, n, jaccard_mat + elif len(indices_tuple) == 5: + a1, p, a2, n, jaccard_mat = indices_tuple + keep_mask = lmu.not_self_comparisons( + a1, p, s, e, curr_batch_idx, ref_size, ref_is_subset + ) + a1 = a1[keep_mask] + p = p[keep_mask] + assert len(a1) == len(p) + assert len(a2) == len(n) + return a1, p, a2, n, jaccard_mat + +# ========================================================================= \ No newline at end of file diff --git a/src/pytorch_metric_learning/losses/xbm_multilabel.py b/src/pytorch_metric_learning/losses/xbm_multilabel.py deleted file mode 100644 index cf9b9f40..00000000 --- a/src/pytorch_metric_learning/losses/xbm_multilabel.py +++ /dev/null @@ -1,132 +0,0 @@ -import torch - -from ..utils import common_functions as c_f -# replace the functions of loss_and_miner_utils by multisupcon's -from ..utils import multilabel_loss_and_miner_utils as mlmu -from ..utils import loss_and_miner_utils as lmu -from ..utils.module_with_records import ModuleWithRecords -from .base_loss_wrapper import BaseLossWrapper - - -class CrossBatchMemory4MultiLabel(BaseLossWrapper, ModuleWithRecords): - def __init__(self, loss, embedding_size, memory_size=1024, miner=None, **kwargs): - super().__init__(loss=loss, **kwargs) - self.loss = loss - self.miner = miner - self.embedding_size = embedding_size - self.memory_size = memory_size - self.num_classes = loss.num_classes - self.reset_queue() - self.add_to_recordable_attributes( - list_of_names=["embedding_size", "memory_size", "queue_idx"], is_stat=False - ) - - @staticmethod - def supported_losses(): - return [ - "MultiSupConLoss" - ] - - @classmethod - def check_loss_support(cls, loss_name): - if loss_name not in cls.supported_losses(): - raise Exception(f"CrossBatchMemory not supported for {loss_name}") - - def forward(self, embeddings, labels, indices_tuple=None, enqueue_mask=None): - if indices_tuple is not None and enqueue_mask is not None: - raise ValueError("indices_tuple and enqueue_mask are mutually exclusive") - if enqueue_mask is not None: - assert len(enqueue_mask) == len(embeddings) - else: - assert len(embeddings) <= len(self.embedding_memory) - self.reset_stats() - device = embeddings.device - self.embedding_memory = c_f.to_device( - self.embedding_memory, device=device, dtype=embeddings.dtype - ) - - if enqueue_mask is not None: - emb_for_queue = embeddings[enqueue_mask] - labels_for_queue = labels[enqueue_mask] - embeddings = embeddings[~enqueue_mask] - labels = labels[~enqueue_mask] - do_remove_self_comparisons = False - else: - emb_for_queue = embeddings - labels_for_queue = labels - do_remove_self_comparisons = True - - queue_batch_size = len(emb_for_queue) - self.add_to_memory(emb_for_queue, labels_for_queue, queue_batch_size) - - if not self.has_been_filled: - E_mem = self.embedding_memory[: self.queue_idx] - L_mem = self.label_memory[: self.queue_idx] - else: - E_mem = self.embedding_memory - L_mem = self.label_memory - indices_tuple = self.create_indices_tuple( - embeddings, - labels, - E_mem, - L_mem, - indices_tuple, - do_remove_self_comparisons, - ) - loss = self.loss(embeddings, labels, indices_tuple, E_mem, L_mem) - return loss - - def add_to_memory(self, embeddings, labels, batch_size): - self.curr_batch_idx = ( - torch.arange( - self.queue_idx, self.queue_idx + batch_size - ) - % self.memory_size - ) - self.embedding_memory[self.curr_batch_idx] = embeddings.detach() - # self.label_memory[self.curr_batch_idx] = labels - for i in range(len(self.curr_batch_idx)): - self.label_memory[self.curr_batch_idx[i]] = labels[i] - prev_queue_idx = self.queue_idx - self.queue_idx = (self.queue_idx + batch_size) % self.memory_size - if (not self.has_been_filled) and (self.queue_idx <= prev_queue_idx): - self.has_been_filled = True - - def create_indices_tuple( - self, - embeddings, - labels, - E_mem, - L_mem, - input_indices_tuple, - do_remove_self_comparisons, - ): - if self.miner: - indices_tuple = self.miner(embeddings, labels, E_mem, L_mem) - else: - indices_tuple = mlmu.get_all_pairs_indices(labels, self.num_classes, L_mem) - if do_remove_self_comparisons: - indices_tuple = lmu.remove_self_comparisons( - indices_tuple, self.curr_batch_idx, self.memory_size - ) - - if input_indices_tuple is not None: - if len(input_indices_tuple) == 3 and len(indices_tuple) == 4: - input_indices_tuple = mlmu.convert_to_pairs(input_indices_tuple, labels, self.num_classes) - elif len(input_indices_tuple) == 4 and len(indices_tuple) == 3: - input_indices_tuple = mlmu.convert_to_triplets( - input_indices_tuple, labels - ) - indices_tuple = c_f.concatenate_indices_tuples( - indices_tuple, input_indices_tuple - ) - - return indices_tuple - - def reset_queue(self): - self.register_buffer( - "embedding_memory", torch.zeros(self.memory_size, self.embedding_size) - ) - self.label_memory = [[] for i in range(self.memory_size)] - self.has_been_filled = False - self.queue_idx = 0 diff --git a/src/pytorch_metric_learning/utils/multilabel_loss_and_miner_utils.py b/src/pytorch_metric_learning/utils/multilabel_loss_and_miner_utils.py deleted file mode 100644 index de08e86d..00000000 --- a/src/pytorch_metric_learning/utils/multilabel_loss_and_miner_utils.py +++ /dev/null @@ -1,101 +0,0 @@ -import torch -from . import loss_and_miner_utils as lmu - -def check_shapes_multilabels(embeddings, labels): - if labels is not None and embeddings.shape[0] != len(labels): - raise ValueError("Number of embeddings must equal number of labels") - if labels is not None: - if isinstance(labels[0], list) or isinstance(labels[0], torch.Tensor): - pass - else: - raise ValueError("labels must be a list of 1d tensors or a list of lists") - -def set_ref_emb(embeddings, labels, ref_emb, ref_labels): - if ref_emb is None: - ref_emb, ref_labels = embeddings, labels - check_shapes_multilabels(ref_emb, ref_labels) - return ref_emb, ref_labels - -def convert_to_pairs(indices_tuple, labels, num_classes, ref_labels=None, device=None): - """ - This returns anchor-positive and anchor-negative indices, - regardless of what the input indices_tuple is - Args: - indices_tuple: tuple of tensors. Each tensor is 1d and specifies indices - within a batch - labels: a tensor which has the label for each element in a batch - """ - if indices_tuple is None: - return get_all_pairs_indices(labels, num_classes, ref_labels, device=device) - elif len(indices_tuple) == 4: - return indices_tuple - else: - a, p, n = indices_tuple - return a, p, a, n - -def get_matches_and_diffs(labels, num_classes, ref_labels=None, device=None): - matches = jaccard(num_classes, labels, ref_labels, device=device) - diffs = matches ^ 1 - if ref_labels is labels: - matches.fill_diagonal_(0) - return matches, diffs - - -def get_all_pairs_indices(labels, num_classes, ref_labels=None, device=None): - """ - Given a tensor of labels, this will return 4 tensors. - The first 2 tensors are the indices which form all positive pairs - The second 2 tensors are the indices which form all negative pairs - """ - matches, diffs = get_matches_and_diffs(labels, num_classes, ref_labels, device) - a1_idx, p_idx = torch.where(matches) - a2_idx, n_idx = torch.where(diffs) - return a1_idx, p_idx, a2_idx, n_idx - -def jaccard(n_classes, labels, ref_labels=None, threshold=0.3, device=torch.device("cpu")): - if ref_labels is None: - ref_labels = labels - # convert multilabels to scatter labels - labels1 = [torch.nn.functional.one_hot(torch.Tensor(label).long(), n_classes).sum(0) for label in labels] - labels2 = [torch.nn.functional.one_hot(torch.Tensor(label).long(), n_classes).sum(0) for label in ref_labels] - # stack and convert to float for calculation convenience - labels1 = torch.stack(labels1).float() - labels2 = torch.stack(labels2).float() - - # compute jaccard similarity - # jaccard = intersection / union - labels1_union = labels1.sum(-1) - labels2_union = labels2.sum(-1) - union = labels1_union.unsqueeze(1) + labels2_union.unsqueeze(0) - intersection = torch.mm(labels1, labels2.T) - jaccard = intersection / (union - intersection) - - # return indices of jaccard similarity above threshold - label_matrix = torch.where(jaccard > threshold, 1, 0).to(device) - return label_matrix - -def convert_to_triplets(indices_tuple, labels, ref_labels=None, t_per_anchor=100): - """ - This returns anchor-positive-negative triplets - regardless of what the input indices_tuple is - """ - if indices_tuple is None: - if t_per_anchor == "all": - return get_all_triplets_indices(labels, ref_labels) - else: - return lmu.get_random_triplet_indices( - labels, ref_labels, t_per_anchor=t_per_anchor - ) - elif len(indices_tuple) == 3: - return indices_tuple - else: - a1, p, a2, n = indices_tuple - p_idx, n_idx = torch.where(a1.unsqueeze(1) == a2) - return a1[p_idx], p[p_idx], n[n_idx] - - -def get_all_triplets_indices(labels, ref_labels=None): - matches, diffs = get_matches_and_diffs(labels, ref_labels) - triplets = matches.unsqueeze(2) * diffs.unsqueeze(1) - return torch.where(triplets) - diff --git a/tests/losses/test_multilabel_supcon_loss.py b/tests/losses/test_multilabel_supcon_loss.py index cca3c61a..fc82fc66 100644 --- a/tests/losses/test_multilabel_supcon_loss.py +++ b/tests/losses/test_multilabel_supcon_loss.py @@ -2,39 +2,139 @@ import torch import numpy as np -import random from pytorch_metric_learning.losses import ( MultiSupConLoss, CrossBatchMemory4MultiLabel ) +from ..zzz_testing_utils.testing_utils import angle_to_coord + +from .. import TEST_DEVICE, TEST_DTYPES class TestMultiSupConLoss(unittest.TestCase): - def test_multi_supcon_loss(self): - n_cls = 10 - n_samples = 16 - n_dim = 256 - loss_func = MultiSupConLoss(num_classes=10) - xbm_loss_func = CrossBatchMemory4MultiLabel(loss_func, n_dim, memory_size=128) - - # # test float32 and float64 - # for dtype in [torch.float32, torch.float64]: - # embeddings = torch.randn(n_samples, n_dim, dtype=dtype) - # labels = [random.sample(range(n_cls), np.random.randint(1, 4)) for i in range(n_samples)] - # loss = loss_func(embeddings, labels) - # self.assertTrue(loss >= 0) - - # # test cuda and cpu - # for device in [torch.device("cpu"),torch.device("cuda")]: - # embeddings = torch.randn(n_samples, n_dim, dtype=dtype, device=device) - # labels = [random.sample(range(n_cls), np.random.randint(1, 4)) for i in range(n_samples)] - # loss = loss_func(embeddings, labels) - # self.assertTrue(loss >= 0) - - # test xbm - batchs = 10 - for b in range(batchs): - embeddings = torch.randn(n_samples, n_dim, dtype=torch.float32) - labels = [random.sample(range(n_cls), np.random.randint(1, 4)) for i in range(n_samples)] - loss = xbm_loss_func(embeddings, labels) - self.assertTrue(loss >= 0) \ No newline at end of file + def __init__(self, methodName: str = "runTest") -> None: + super().__init__(methodName) + self.n_cls = 3 + self.n_samples = 4 + self.n_dim = 3 + self.n_batchs = 10 + self.xbm_max_size = 1024 + + # multi_supcon + self.loss_func = MultiSupConLoss( + num_classes=self.n_cls, + temperature=0.07, + threshold=0.3) + + # xbm + self.xbm_loss_func = CrossBatchMemory4MultiLabel( + self.loss_func, + self.n_dim, + memory_size=self.xbm_max_size) + # test cases + self.embeddings = torch.tensor([[0.1, 0.3, 0.1], + [0.23, -0.2, -0.1], + [0.1, -0.16, 0.1], + [0.13, -0.13, 0.2]]) + self.labels = torch.tensor([[1,0,1], [1,0,0], [0,1,1], [0,1,0]]) + + # the gt values are obtained by running the code + # (https://github.com/WolodjaZ/MultiSupContrast/blob/main/losses.py) + + # multi_supcon test cases + self.test_multisupcon_val_gt = { + torch.float16: 3.2836, + torch.float32: 3.2874, + torch.float64: 3.2874, + } + # xbm test cases + self.test_xbm_multisupcon_val_gt = { + torch.float16: [3.2836, 4.3792, 4.4588, 4.5741, 4.6831, 4.7809, 4.8682, 4.9465, 5.0174, 5.0819], + torch.float32: [3.2874, 4.3779, 4.4577, 4.5730, 4.6820, 4.7798, 4.8671, 4.9455, 5.0163, 5.0808], + torch.float64: [3.2874, 4.3779, 4.4577, 4.5730, 4.6820, 4.7798, 4.8671, 4.9455, 5.0163, 5.0808,] + } + + + def test_multisupcon_val(self): + for dtype in TEST_DTYPES: + for device in ["cpu", "cuda"]: + # skip float16 on cpu + if device == "cpu" and dtype == torch.float16: + continue + embedding = self.embeddings.to(device).to(dtype) + label = self.labels.to(device).to(dtype) + loss = self.loss_func(embedding, label) + loss = loss.to("cpu") + self.assertTrue(np.isclose( + loss.item(), + self.test_multisupcon_val_gt[dtype], + atol=1e-2 if dtype == torch.float16 else 1e-4)) + + + def test_xbm_multisupcon_val(self): + # test xbm with scatter labels + for dtype in TEST_DTYPES: + for device in ["cpu", "cuda"]: + # skip float16 on cpu + if device == "cpu" and dtype == torch.float16: + continue + self.xbm_loss_func.reset_queue() + for b in range(self.n_batchs): + embedding = self.embeddings.to(device).to(dtype) + label = self.labels.to(device).to(dtype) + loss = self.xbm_loss_func(embedding, label) + loss = loss.to("cpu") + print(loss, self.test_xbm_multisupcon_val_gt[dtype][b], dtype) + self.assertTrue(np.isclose( + loss.item(), + self.test_xbm_multisupcon_val_gt[dtype][b], + atol=1e-2 if dtype == torch.float16 else 1e-4)) + + def test_with_no_valid_pairs(self): + for dtype in TEST_DTYPES: + embedding_angles = [0] + embeddings = torch.tensor( + [angle_to_coord(a) for a in embedding_angles], + requires_grad=True, + dtype=dtype, + ).to( + TEST_DEVICE + ) # 2D embeddings + labels = torch.LongTensor([[0]]) + loss = self.loss_func(embeddings, labels) + loss.backward() + self.assertEqual(loss, 0) + + def test_(self): + for dtype in TEST_DTYPES: + embedding_angles = [0] + embeddings = torch.tensor( + [angle_to_coord(a) for a in embedding_angles], + requires_grad=True, + dtype=dtype, + ).to( + TEST_DEVICE + ) # 2D embeddings + labels = torch.LongTensor([[0]]) + loss = self.loss_func(embeddings, labels) + loss.backward() + self.assertEqual(loss, 0) + + + def test_backward(self): + for dtype in TEST_DTYPES: + embedding_angles = list(range(0, 180, 20))[:4] + embeddings = torch.tensor( + [angle_to_coord(a) for a in embedding_angles], + requires_grad=True, + dtype=dtype, + ).to( + TEST_DEVICE + ) # 2D embeddings + labels = torch.LongTensor([[0, 0, 1, 0, 1, 0, 0], + [1, 0, 1, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0], + [0, 0, 1, 0, 1, 0, 1]]).to(TEST_DEVICE) + + loss = self.loss_func(embeddings, labels) + loss.backward() \ No newline at end of file From 855baa6ae9530d429325f09861932fc4e08ac291 Mon Sep 17 00:00:00 2001 From: Qibin Liang <46101145+PhyseChan@users.noreply.github.com> Date: Wed, 11 Oct 2023 14:26:51 +0800 Subject: [PATCH 09/10] Add docstring --- .../losses/multilabel_supcon_loss.py | 25 ++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/src/pytorch_metric_learning/losses/multilabel_supcon_loss.py b/src/pytorch_metric_learning/losses/multilabel_supcon_loss.py index a8e226ab..c8293918 100644 --- a/src/pytorch_metric_learning/losses/multilabel_supcon_loss.py +++ b/src/pytorch_metric_learning/losses/multilabel_supcon_loss.py @@ -11,6 +11,22 @@ # adapted from https://github.com/HobbitLong/SupContrast # modified for multi-supcon class MultiSupConLoss(GenericPairLoss): + """ + Args: + num_classes: number of classes + temperature: temperature for scaling the similarity matrix + threshold: threshold for jaccard similarity + + Inputs: + embeddings: tensor of size (batch_size, embedding_size) + labels: tensor of size (batch_size, num_classes) + each row is a binary vector of size num_classes that only has 1s for the positive + labels, and 0s for the negative labels + indices_tuple: tuple of size 4 for triplets (anchors, positives, negatives, jaccard_matrix) + or size 5 for pairs (anchor1, postives, anchor2, negativesm, jaccard_matrix) + Can also be left as None + ref_emb: tensor of size (batch_size, embedding_size) + """ def __init__(self, num_classes, temperature=0.1, threshold=0.3, **kwargs): super().__init__(mat_based_loss=True, **kwargs) self.temperature = temperature @@ -77,10 +93,13 @@ def forward( """ Args: embeddings: tensor of size (batch_size, embedding_size) - labels: tensor of size (batch_size) - indices_tuple: tuple of size 3 for triplets (anchors, positives, negatives) - or size 4 for pairs (anchor1, postives, anchor2, negatives) + labels: tensor of size (batch_size, num_classes) + each row is a binary vector of size num_classes that only has 1s for the positive + labels, and 0s for the negative labels + indices_tuple: tuple of size 4 for triplets (anchors, positives, negatives, jaccard_matrix) + or size 5 for pairs (anchor1, postives, anchor2, negativesm, jaccard_matrix) Can also be left as None + ref_emb: tensor of size (batch_size, embedding_size) Returns: the loss """ self.reset_stats() From e4b1f39586cfc76572ea5d875283c152638e3f39 Mon Sep 17 00:00:00 2001 From: Qibin Liang <46101145+PhyseChan@users.noreply.github.com> Date: Wed, 11 Oct 2023 14:33:40 +0800 Subject: [PATCH 10/10] Update test case --- tests/losses/test_multilabel_supcon_loss.py | 16 +--------------- 1 file changed, 1 insertion(+), 15 deletions(-) diff --git a/tests/losses/test_multilabel_supcon_loss.py b/tests/losses/test_multilabel_supcon_loss.py index fc82fc66..c0584448 100644 --- a/tests/losses/test_multilabel_supcon_loss.py +++ b/tests/losses/test_multilabel_supcon_loss.py @@ -90,6 +90,7 @@ def test_xbm_multisupcon_val(self): self.test_xbm_multisupcon_val_gt[dtype][b], atol=1e-2 if dtype == torch.float16 else 1e-4)) + def test_with_no_valid_pairs(self): for dtype in TEST_DTYPES: embedding_angles = [0] @@ -105,21 +106,6 @@ def test_with_no_valid_pairs(self): loss.backward() self.assertEqual(loss, 0) - def test_(self): - for dtype in TEST_DTYPES: - embedding_angles = [0] - embeddings = torch.tensor( - [angle_to_coord(a) for a in embedding_angles], - requires_grad=True, - dtype=dtype, - ).to( - TEST_DEVICE - ) # 2D embeddings - labels = torch.LongTensor([[0]]) - loss = self.loss_func(embeddings, labels) - loss.backward() - self.assertEqual(loss, 0) - def test_backward(self): for dtype in TEST_DTYPES: