Skip to content

Commit 37588ed

Browse files
author
Mark-ZhouWX
committed
add point finetune: fix bug of valid_boxes
1 parent 51c481e commit 37588ed

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

official/cv/segment-anything/segment_anything/modeling/loss.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def __init__(self, focal_factor=20.0, dice_factor=1.0, mse_factor=1.0, mask_thre
3232
self.dice_loss = DiceLoss(reduction='none')
3333
self.mse_loss = nn.MSELoss(reduction='none')
3434

35-
def construct(self, pred_mask, pred_iou, gt_mask, valid_boxes=None):
35+
def construct(self, pred_mask, pred_iou, gt_mask, valid_boxes):
3636
"""
3737
get loss, remove dynamic shape assisted with valid_boxes
3838
Args:

official/cv/segment-anything/segment_anything/utils/model_wrapper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ def forward_point(image, points=None, boxes=None, masks=None,
224224
multimask_output=False, output_best_mask=True, return_low_res_mask=True):
225225
_pred_mask, _pred_iou, _low_res_mask = net(image, points=points, boxes=boxes, masks=masks,
226226
multimask_output=multimask_output, output_best_mask=output_best_mask, return_low_res_mask=return_low_res_mask)
227-
_loss = loss_fn(_pred_mask, _pred_iou, gt_mask=gt_mask, valid_boxes=valid_boxes)
227+
_loss = loss_fn(_pred_mask, _pred_iou, gt_mask, valid_boxes)
228228
return loss_scaler.scale(_loss[0]), (_pred_mask, _pred_iou, _low_res_mask)
229229

230230
def _train_fn(*data_element):

0 commit comments

Comments
 (0)