Skip to content

Commit 9db9993

Browse files
author
Mark-ZhouWX
committed
fix bug of callback hack update
1 parent 57fb8e5 commit 9db9993

File tree

3 files changed

+141
-2
lines changed

3 files changed

+141
-2
lines changed
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
#---------------------------------------------
2+
# Part 1: system basic config setting
3+
distributed: False
4+
device: Ascend
5+
mode: 0 # 0: graph, 1: pynative
6+
work_root: &work_root ./work_dir/
7+
log_level: info
8+
9+
10+
# ---------------------------------------------
11+
# Part2: module setting
12+
loss_manager:
13+
# type: fixed # dynamic or
14+
# scale_sense: 1024
15+
loss_scaler:
16+
type: dynamic
17+
grad_clip: False
18+
19+
20+
optimizer:
21+
type: segment_anything.optim.optimizer.AdamW
22+
weight_decay: 1e-4
23+
group_param:
24+
25+
lr_scheduler:
26+
type: segment_anything.optim.scheduler.SAMDynamicDecayLR
27+
learning_rate: 8e-6
28+
warmup_steps: 250
29+
decay_steps: [ 60000, 86666 ]
30+
decay_factor: 10
31+
32+
33+
network:
34+
model:
35+
type: vit_b
36+
checkpoint: ./models/sam_vit_b-35e4849c.ckpt
37+
freeze:
38+
image_encoder: True
39+
prompt_encoder: True
40+
41+
loss:
42+
type: segment_anything.modeling.loss.SAMLoss
43+
44+
45+
train_loader:
46+
dataset:
47+
type: segment_anything.dataset.dataset.COCODataset
48+
data_dir: ./datasets/coco2017/train2017
49+
annotation_path: ./datasets/coco2017/annotations/instances_train2017.json
50+
transform_pipeline:
51+
- type: segment_anything.dataset.transform.ImageResizeAndPad
52+
target_size: 1024
53+
- type: segment_anything.dataset.transform.ImageNorm
54+
hwc2chw: True
55+
- type: segment_anything.dataset.transform.LabelPad
56+
gt_size: 20
57+
output_column: ['image', 'masks', 'boxes', 'valid_boxes']
58+
59+
model_column: ['image', 'boxes'] # columns for model cell input
60+
loss_column: ['masks', 'valid_boxes'] # columns for loss function input
61+
62+
shuffle: True
63+
batch_size: 1
64+
epoch_size: 50
65+
drop_remainder: True
66+
num_workers: 2
67+
max_rowsize: 24 # 24M space for dataloader
68+
69+
70+
callback:
71+
- type: segment_anything.utils.callbacks.Profiler
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
#---------------------------------------------
2+
# Part 1: system basic config setting
3+
distributed: False
4+
device: Ascend
5+
mode: 0 # 0: graph, 1: pynative
6+
work_root: &work_root ./work_dir/
7+
log_level: info
8+
9+
10+
# ---------------------------------------------
11+
# Part2: module setting
12+
loss_manager:
13+
# type: fixed # dynamic or
14+
# scale_sense: 1024
15+
loss_scaler:
16+
type: dynamic
17+
grad_clip: False
18+
19+
20+
optimizer:
21+
type: segment_anything.optim.optimizer.AdamW
22+
weight_decay: 1e-4
23+
group_param:
24+
25+
lr_scheduler:
26+
type: segment_anything.optim.scheduler.SAMDynamicDecayLR
27+
learning_rate: 8e-6
28+
warmup_steps: 250
29+
decay_steps: [ 60000, 86666 ]
30+
decay_factor: 10
31+
32+
33+
network:
34+
model:
35+
type: vit_b
36+
checkpoint: ./models/sam_vit_b-35e4849c.ckpt
37+
freeze:
38+
image_encoder: True
39+
prompt_encoder: True
40+
41+
loss:
42+
type: segment_anything.modeling.loss.SAMLoss
43+
44+
45+
train_loader:
46+
dataset:
47+
type: segment_anything.dataset.dataset.FLAREDataset
48+
data_dir: ./datasets/FLARE22Train_processed/train/
49+
transform_pipeline:
50+
- type: segment_anything.dataset.transform.BinaryMaskFromInstanceSeg
51+
- type: segment_anything.dataset.transform.BoxFormMask
52+
- type: segment_anything.dataset.transform.LabelPad
53+
gt_size: 20
54+
output_column: ['image', 'masks', 'boxes', 'valid_boxes' ]
55+
56+
model_column: ['image', 'boxes' ] # columns for model cell input
57+
loss_column: ['masks', 'valid_boxes' ] # columns for loss function input
58+
59+
shuffle: True
60+
batch_size: 1
61+
epoch_size: 50
62+
drop_remainder: True
63+
num_workers: 2
64+
max_rowsize: 64 # 24M space for dataloader
65+
66+
67+
callback:
68+
- type: segment_anything.utils.callbacks.Profiler

research/segment-anything/segment_anything/utils/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def update_rank_to_dataloader_config(rank_id, rank_size, args_train_loader, args
105105
# a hack implementation to update runtime-defined setting in callback args
106106
if arg_callback is not None:
107107
for cb in arg_callback:
108-
if cb.type == 'EvalWhileTrain':
108+
if cb.type.endswith('EvalWhileTrain'):
109109
cb.data_loader = args_eval_loader
110110

111111

@@ -130,7 +130,7 @@ def set_directory_and_log(main_device, rank_id, rank_size, work_root, log_level,
130130
# a hack implementation to update runtime-defined setting in callback args
131131
if args_callback is not None:
132132
for cb in args_callback:
133-
if cb.type == 'SaveCkpt':
133+
if cb.type.endswith('SaveCkpt'):
134134
hack_list = {'save_dir': save_dir, 'main_device': main_device}
135135
cb.update(hack_list)
136136
return save_dir

0 commit comments

Comments
 (0)