Skip to content

Commit 81cc03d

Browse files
author
Mark-ZhouWX
committed
add point finetune: prevent duplicate graph compile of grad reducer
1 parent a1929ce commit 81cc03d

File tree

1 file changed: 12 additions (+12) and 3 deletions (-3).

official/cv/segment-anything/segment_anything/utils/model_wrapper.py

Lines changed: 12 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -209,6 +209,15 @@ def _build_train_network(self):
209209

210210
# for training only
211211
net.set_train(True)
212+
213+
@ms.jit(compile_once=True)
214+
def grad_reducer_wrapper(grads):
215+
return grad_reducer(grads)
216+
217+
@ms.jit(compile_once=True)
218+
def optimizer_wrapper(grads):
219+
optimizer(grads)
220+
212221
@ms.jit
213222
def forward_point(image, points=None, boxes=None, masks=None,
214223
gt_mask=None, valid_boxes=None,
@@ -248,7 +257,7 @@ def _train_fn(*data_element):
248257
# print(f'get next takes: {s1-s0:.2f}s')
249258
(loss, (mask, iou, low_res_mask)), grads = grad_fn(
250259
input_dict['image'],
251-
ms.mutable(point_and_label),
260+
ms.mutable(point_and_label), # mutable tuple to prevent duplicate graph compiling
252261
None, # box
253262
previous_low_mask,
254263
gt_dict['masks'],
@@ -274,9 +283,9 @@ def _train_fn(*data_element):
274283

275284
# print(f'loss list', loss_list)
276285
t0 = time.time()
277-
grad_accum = grad_reducer(grad_accum)
286+
grad_accum = grad_reducer_wrapper(ms.mutable(grad_accum)) # mutable tuple to prevent duplicate graph compiling
278287
if np.all(grad_finite_list):
279-
optimizer(grad_accum)
288+
optimizer_wrapper(ms.mutable(grad_accum)) # mutable tuple to prevent duplicate graph compiling
280289
else:
281290
print(f'gradient overflow')
282291
t1 = time.time()

0 commit comments

Comments
 (0)