Commit a1929ce

Author: Mark-ZhouWX
Message: add point finetune: lr schedule from dynamic to pre-computed
Parent: 53a779d

File tree: 5 files changed (+40, -8 lines)

official/cv/segment-anything/configs/sa1b_point_finetune.yaml
Lines changed: 1 addition & 1 deletion

@@ -23,7 +23,7 @@ optimizer:
   group_param:
 
   lr_scheduler:
-    type: segment_anything.optim.scheduler.SAMDynamicDecayLR
+    type: segment_anything.optim.scheduler.sam_dynamic_decay_lr
     learning_rate: 8e-6
     warmup_steps: 250
     decay_steps: [ 60000, 86666 ]
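
The schedule itself is unchanged by this commit; only its implementation moves from a dynamic LearningRateSchedule cell to a pre-computed per-step list. Read together, the fields above describe a piecewise schedule: linear warmup for 250 steps, a constant plateau, then two step decays. A minimal sketch of that shape, mirroring the lr_factor helper in the scheduler diff below (decay_factor is configured elsewhere in the file and not visible in this hunk; 10 is an illustrative assumption):

learning_rate = 8e-6
warmup_steps = 250
decay_steps = [60000, 86666]
decay_factor = 10  # assumption: not shown in this hunk

def lr_at(step):
    # learning rate at a 1-based training step
    if step < warmup_steps:          # linear warmup
        return learning_rate * step / float(warmup_steps)
    elif step < decay_steps[0]:      # constant plateau
        return learning_rate
    elif step < decay_steps[1]:      # first decay
        return learning_rate / decay_factor
    else:                            # second decay
        return learning_rate / (decay_factor ** 2)

print(lr_at(125), lr_at(30000), lr_at(70000), lr_at(90000))
# -> 4e-06 8e-06 8e-07 8e-08 (up to float rounding)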

official/cv/segment-anything/segment_anything/optim/optimizer.py
Lines changed: 6 additions & 3 deletions

@@ -10,6 +10,8 @@
 def create_optimizer(
         params,
         args,
+        step_per_epoch,
+        epoch_size
 ):
     r"""Creates optimizer by name.

@@ -25,15 +27,16 @@ def create_optimizer(
     Returns:
         Optimizer object
     """
-    optimizer = OPTIMIZER_REGISTRY.instantiate(**args, params=params)
+    optimizer = OPTIMIZER_REGISTRY.instantiate(**args, params=params,
+                                               step_per_epoch=step_per_epoch, epoch_size=epoch_size)
     return optimizer
 
 
 @OPTIMIZER_REGISTRY.registry_module()
 class AdamW(nn.optim.Adam):
-    def __init__(self, params: List, lr_scheduler, group_param, **kwargs):
+    def __init__(self, params: List, lr_scheduler, group_param, step_per_epoch, epoch_size, **kwargs):
         if group_param is None:
             group_param = dict()
         params = create_group_param(params, **group_param)
-        lr_scheduler_inst = create_lr_scheduler(lr_scheduler)
+        lr_scheduler_inst = create_lr_scheduler(lr_scheduler, step_per_epoch=step_per_epoch, epoch_size=epoch_size)
         super().__init__(params, lr_scheduler_inst, **kwargs)
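
For context on why the result of create_lr_scheduler can be handed straight to super().__init__ here: MindSpore optimizers accept a plain list of floats as learning_rate, interpreted as one value per training step, in addition to a scalar or a LearningRateSchedule cell. A minimal sketch with toy values (not from this repo):

from mindspore import nn

net = nn.Dense(4, 2)
per_step_lr = [1e-3 * (i + 1) / 10.0 for i in range(10)]  # one entry per step
opt = nn.Adam(net.trainable_params(), learning_rate=per_step_lr)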

official/cv/segment-anything/segment_anything/optim/scheduler.py
Lines changed: 29 additions & 2 deletions

@@ -7,15 +7,42 @@
 from segment_anything.utils.registry import LR_SCHEDULER_REGISTRY
 
 
-def create_lr_scheduler(args: Dict):
+def create_lr_scheduler(args, step_per_epoch, epoch_size):
     """
     instantiate learning rate scheduler class
     """
+    if args.type.endswith('sam_dynamic_decay_lr'):
+        return sam_dynamic_decay_lr(learning_rate=args.learning_rate,
+                                    warmup_steps=args.warmup_steps,
+                                    decay_steps=args.decay_steps,
+                                    decay_factor=args.decay_factor,
+                                    step_per_epoch=step_per_epoch,
+                                    epoch_size=epoch_size,
+                                    )
     scheduler = LR_SCHEDULER_REGISTRY.instantiate(**args)
     return scheduler
 
+def sam_dynamic_decay_lr(learning_rate, warmup_steps, decay_steps, decay_factor, step_per_epoch, epoch_size):
+    def lr_factor(step):
+        if step < warmup_steps:
+            return step / float(warmup_steps)
+        elif step < decay_steps[0]:
+            return 1.0
+        elif step < decay_steps[1]:
+            return 1.0 / decay_factor
+        else:
+            return 1.0 / (decay_factor**2)
+    total_step = step_per_epoch * epoch_size
+    lr_list = []
+    for i in range(total_step):
+        step = i + 1
+        lr = learning_rate * lr_factor(step)
+        lr_list.append(lr)
+
+    return lr_list
+
 
-@LR_SCHEDULER_REGISTRY.registry_module()
+# @LR_SCHEDULER_REGISTRY.registry_module()
 class SAMDynamicDecayLR(LearningRateSchedule):
     def __init__(self,
                  learning_rate: float,
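
Because sam_dynamic_decay_lr is an ordinary function returning a Python list, it can be exercised standalone. A quick check using the finetune config's values (decay_factor=10 and the step counts are illustrative assumptions):

from segment_anything.optim.scheduler import sam_dynamic_decay_lr

lr_list = sam_dynamic_decay_lr(learning_rate=8e-6, warmup_steps=250,
                               decay_steps=[60000, 86666], decay_factor=10,
                               step_per_epoch=1000, epoch_size=90)
print(len(lr_list))  # 90000: one entry per training step
print(lr_list[0])    # 3.2e-08: first warmup step (1/250 of the peak)
print(lr_list[249])  # 8e-06: warmup finished, plateau begins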

official/cv/segment-anything/segment_anything/utils/callbacks.py
Lines changed: 1 addition & 1 deletion

@@ -130,7 +130,7 @@ def on_train_step_end(self, run_context: RunContext):
         self.accumulate_loss += loss
 
         if cur_step % self.log_interval == 0:
-            lr = cb_params.network.optimizer.learning_rate(cur_step)
+            lr = cb_params.network.optimizer.learning_rate.learning_rate[cur_step]
             smooth_loss = self.accumulate_loss / self.log_interval
 
             step_cost = time.time() - self.step_start_time
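
The indexing change follows from how MindSpore stores a list-valued learning rate: as far as I can tell, the optimizer wraps the list in an internal schedule whose learning_rate attribute is a Parameter holding all per-step values, so a callback reads the current value by index rather than by calling a schedule cell with the step. A minimal sketch with toy values:

from mindspore import nn

net = nn.Dense(4, 2)
opt = nn.Adam(net.trainable_params(), learning_rate=[1e-3, 9e-4, 8e-4, 7e-4])

# the wrapper's learning_rate attribute holds one entry per step
print(opt.learning_rate.learning_rate[2])  # value at step index 2, i.e. 8e-4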

official/cv/segment-anything/train.py
Lines changed: 3 additions & 1 deletion

@@ -37,7 +37,9 @@ def main(args) -> None:
     network = amp.auto_mixed_precision(network, args.get('amp_level', 'O0'))
 
     # Step3: create optimizer, including learning rate scheduler and group parameter settings
-    optimizer = create_optimizer(params=network.trainable_params(), args=args.optimizer)
+    optimizer = create_optimizer(params=network.trainable_params(), args=args.optimizer,
+                                 step_per_epoch=train_dataloader.get_dataset_size(),
+                                 epoch_size=args.train_loader.epoch_size)
 
     # Step4: wrap model and optimizer for training
     with_loss_model = NetWithLossWrapper(network, loss_fn=loss_fn,
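
One consequence of pre-computing the schedule: its length is fixed at step_per_epoch * epoch_size, so the step counter used for indexing must stay within that horizon. In MindSpore, get_dataset_size() returns the number of batches per epoch, making the product the total number of optimizer steps (numbers below are illustrative stand-ins):

step_per_epoch = 1000  # stand-in for train_dataloader.get_dataset_size()
epoch_size = 90        # stand-in for args.train_loader.epoch_size
total_steps = step_per_epoch * epoch_size  # 90000 entries in the pre-computed list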
