@@ -209,6 +209,15 @@ def _build_train_network(self):
 
         # for training only
         net.set_train(True)
+
+        @ms.jit(compile_once=True)
+        def grad_reducer_wrapper(grads):
+            return grad_reducer(grads)
+
+        @ms.jit(compile_once=True)
+        def optimizer_wrapper(grads):
+            optimizer(grads)
+
         @ms.jit
         def forward_point(image, points=None, boxes=None, masks=None,
                           gt_mask=None, valid_boxes=None,
@@ -248,7 +257,7 @@ def _train_fn(*data_element):
             # print(f'get next takes: {s1-s0:.2f}s')
             (loss, (mask, iou, low_res_mask)), grads = grad_fn(
                 input_dict['image'],
-                ms.mutable(point_and_label),
+                ms.mutable(point_and_label),  # mutable tuple to prevent duplicate graph compiling
                 None,  # box
                 previous_low_mask,
                 gt_dict['masks'],
@@ -274,9 +283,9 @@ def _train_fn(*data_element):
 
             # print(f'loss list', loss_list)
             t0 = time.time()
-            grad_accum = grad_reducer(grad_accum)
+            grad_accum = grad_reducer_wrapper(ms.mutable(grad_accum))  # mutable tuple to prevent duplicate graph compiling
             if np.all(grad_finite_list):
-                optimizer(grad_accum)
+                optimizer_wrapper(ms.mutable(grad_accum))  # mutable tuple to prevent duplicate graph compiling
             else:
                 print(f'gradient overflow')
             t1 = time.time()
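
Background for reviewers: MindSpore's graph compiler treats tuple arguments as constants, so each step's new gradient tuple would otherwise trigger a fresh graph compile; `ms.mutable` marks the tuple as variable data so one compiled graph is reused, and the `@ms.jit` wrappers keep the reducer/optimizer calls inside compiled graphs. Below is a minimal, self-contained sketch of the same pattern; the `nn.Dense` network, `nn.SGD` optimizer, loss, and data shapes are placeholders for illustration, not this repo's model.

```python
import mindspore as ms
from mindspore import nn, ops

# Placeholder network/optimizer, standing in for the commit's training members.
net = nn.Dense(4, 2)
optimizer = nn.SGD(net.trainable_params(), learning_rate=0.01)

def loss_fn(x, y):
    return ops.mse_loss(net(x), y)

# Returns (loss, grads) where grads is a tuple of gradient Tensors.
grad_fn = ms.value_and_grad(loss_fn, None, optimizer.parameters)

@ms.jit
def optimizer_step(grads):
    # grads arrives as ms.mutable(...): new gradient values between steps
    # reuse this compiled graph instead of forcing a recompile.
    optimizer(grads)

x = ops.ones((8, 4), ms.float32)
y = ops.zeros((8, 2), ms.float32)
loss, grads = grad_fn(x, y)
optimizer_step(ms.mutable(grads))  # mark the tuple as variable, not constant
```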