diff --git a/ssseg/modules/models/segmentors/idrnet/idrnet.py b/ssseg/modules/models/segmentors/idrnet/idrnet.py index facb2338..3806d1ef 100644 --- a/ssseg/modules/models/segmentors/idrnet/idrnet.py +++ b/ssseg/modules/models/segmentors/idrnet/idrnet.py @@ -323,7 +323,7 @@ def obtainidcontext(self, context, logits, class_relations, intervention_clsids= cls_contexts = torch.stack(cls_contexts) selected_class_relations = torch.cat(selected_class_relations, dim=1) if remove_negative_cls_relation: - selected_class_relations[selected_class_relations <= 0] = -1e16 + selected_class_relations[selected_class_relations <= 0] = -1e4 selected_class_relations = F.softmax(selected_class_relations, dim=1) selected_class_relations_tmp = [] for cls_id in valid_clsids: diff --git a/ssseg/modules/models/segmentors/isnet/semanticlevel.py b/ssseg/modules/models/segmentors/isnet/semanticlevel.py index 390e4592..07279fd6 100644 --- a/ssseg/modules/models/segmentors/isnet/semanticlevel.py +++ b/ssseg/modules/models/segmentors/isnet/semanticlevel.py @@ -31,7 +31,7 @@ def forward(self, x, preds, feats_il): inputs = x batch_size, num_channels, h, w = x.size() num_classes = preds.size(1) - feats_sl = torch.zeros(batch_size, h*w, num_channels).type_as(x) + feats_sl = torch.zeros(batch_size, h*w, num_channels, device=x.device, dtype=torch.float32) for batch_idx in range(batch_size): # (C, H, W), (num_classes, H, W) --> (H*W, C), (H*W, num_classes) feats_iter, preds_iter = x[batch_idx], preds[batch_idx]