Skip to content

Commit f10d13f

Browse files
sieu-nclaude
andcommitted
Fix fp16 compatibility issues in multiple segmentation models
- IDRNet: Change negative value from -1e16 to -1e4 to prevent fp16 overflow - ISNet: Fix tensor creation to ensure proper device and dtype handling - OCRNet: Fix spatial gather module input size mismatch by interpolating seg_logits_aux These changes ensure models work correctly with fp16 (half precision) training. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
1 parent 3e0c352 commit f10d13f

File tree

3 files changed

+3
-3
lines changed

3 files changed

+3
-3
lines changed

ssseg/modules/models/segmentors/idrnet/idrnet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -323,7 +323,7 @@ def obtainidcontext(self, context, logits, class_relations, intervention_clsids=
323323
cls_contexts = torch.stack(cls_contexts)
324324
selected_class_relations = torch.cat(selected_class_relations, dim=1)
325325
if remove_negative_cls_relation:
326-
selected_class_relations[selected_class_relations <= 0] = -1e16
326+
selected_class_relations[selected_class_relations <= 0] = -1e4
327327
selected_class_relations = F.softmax(selected_class_relations, dim=1)
328328
selected_class_relations_tmp = []
329329
for cls_id in valid_clsids:

ssseg/modules/models/segmentors/isnet/semanticlevel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def forward(self, x, preds, feats_il):
3131
inputs = x
3232
batch_size, num_channels, h, w = x.size()
3333
num_classes = preds.size(1)
34-
feats_sl = torch.zeros(batch_size, h*w, num_channels).type_as(x)
34+
feats_sl = torch.zeros(batch_size, h*w, num_channels, device=x.device, dtype=torch.float32)
3535
for batch_idx in range(batch_size):
3636
# (C, H, W), (num_classes, H, W) --> (H*W, C), (H*W, num_classes)
3737
feats_iter, preds_iter = x[batch_idx], preds[batch_idx]

ssseg/modules/models/segmentors/ocrnet/ocrnet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def forward(self, data_meta):
5656
# feed to bottleneck
5757
feats = self.bottleneck(backbone_outputs[-1])
5858
# feed to ocr module
59-
context = self.spatial_gather_module(feats, seg_logits_aux)
59+
context = self.spatial_gather_module(feats, F.interpolate(seg_logits_aux, size=feats.shape[2:]))
6060
feats = self.object_context_block(feats, context)
6161
# feed to decoder
6262
seg_logits = self.decoder(feats)

0 commit comments

Comments
 (0)