Skip to content

Commit 84f697b

Browse files
author
Mark-ZhouWX
committed
add point finetune: fix all_finite bug for distributed training
1 parent 37588ed commit 84f697b

File tree

3 files changed

+5
-2
lines changed

3 files changed

+5
-2
lines changed

official/cv/segment-anything/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,4 @@ omegaconf==2.0.0
66
# Optional. for preprocess medical ct and mr dataset
77
# connected-components-3d
88
# SimpleITK
9+
# scikit-image

official/cv/segment-anything/scripts/preprocess_CT_MR_dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@
7878
# remove label ids
7979
for remove_label_id in remove_label_ids:
8080
gt_data_ori[gt_data_ori == remove_label_id] = 0
81-
all_labels = np.unique(gt_data_ori).sort()[1:]
81+
all_labels = np.sort(np.unique(gt_data_ori))[1:]
8282

8383
# remove obj with more than one connected area
8484
# for l in all_labels:

official/cv/segment-anything/segment_anything/utils/model_wrapper.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,9 @@ def _train_fn(*data_element):
284284
# print(f'loss list', loss_list)
285285
t0 = time.time()
286286
grad_accum = grad_reducer_wrapper(ms.mutable(grad_accum)) # mutable tuple to prevent duplicate graph compiling
287-
if np.all(grad_finite_list):
287+
# all finite should be after grad reduce for multi node
288+
grad_accum_finite = all_finite(grad_accum)
289+
if grad_accum_finite:
288290
optimizer_wrapper(ms.mutable(grad_accum)) # mutable tuple to prevent duplicate graph compiling
289291
else:
290292
print(f'gradient overflow')

0 commit comments

Comments (0)