import torch
from torch.optim import Optimizer
from math import pi, cos
-
-
-def flat_and_anneal_lr_scheduler(optimizer, total_iters, \
-        warmup_iters=0, warmup_factor=0.1, warmup_method='linear', \
-        anneal_point=0.72, anneal_method='cosine', target_lr_factor=0, \
-        poly_power=1.0, step_gamma=0.1, steps=[0.5, 0.75], \
+import warnings
+
+
+def flat_and_anneal_lr_scheduler(
+    optimizer,
+    total_iters,
+    warmup_iters=0,
+    warmup_factor=0.1,
+    warmup_method="linear",
+    anneal_point=0.72,
+    anneal_method="cosine",
+    target_lr_factor=0,
+    poly_power=1.0,
+    step_gamma=0.1,
+    steps=[2 / 3.0, 8 / 9.0],
+    return_function=False,
):
-    """https://github.com/fastai/fastai/blob/master/fastai/callbacks/flat_cos_anneal.py
+    """Ref: https://github.com/fastai/fastai/blob/master/fastai/callbacks/flat_cos_anneal.py.
+
    warmup_initial_lr = warmup_factor * base_lr
    target_lr = base_lr * target_lr_factor
+    total_iters: cycle length; set it to max_iter to get a one-cycle schedule.
    """
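    # For example (illustrative numbers, not defaults of this function): with base_lr=1e-3 and
    # warmup_factor=0.1, warmup starts at lr=1e-4 and ramps up to 1e-3; with target_lr_factor=0.01
    # the lr is annealed towards 1e-5 by the end of the cycle.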
    if warmup_method not in ("constant", "linear"):
-        raise ValueError("Only 'constant' or 'linear' warmup_method accepted,"
-                         "got {}".format(warmup_method))
+        raise ValueError("Only 'constant' or 'linear' warmup_method accepted, got {}".format(warmup_method))

-    if anneal_method not in ("cosine", "linear", "poly", "exp", "step"):
-        raise ValueError("Only 'cosine', 'linear', 'poly', 'exp' or 'step' anneal_method accepted,"
-                         "got {}".format(anneal_method))
+    if anneal_method not in ("cosine", "linear", "poly", "exp", "step", "none"):
+        raise ValueError(
+            "Only 'cosine', 'linear', 'poly', 'exp', 'step' or 'none' anneal_method accepted, "
+            "got {}".format(anneal_method)
+        )

-    if anneal_method == 'step':
+    if anneal_method == "step":
        if any([_step < warmup_iters / total_iters or _step > 1 for _step in steps]):
-            raise ValueError("error in steps: {}. warmup_iters: {} total_iters: {}."
-                             "steps should be in ({},1)".format(steps, warmup_iters, total_iters, \
-                                                                warmup_iters / total_iters))
-        if steps != sorted(steps):
-            raise ValueError("steps {} is not in ascending order.")
-        print("ignore anneal_point when using step anneal_method")
+            raise ValueError(
+                "error in steps: {}. warmup_iters: {} total_iters: {}. "
+                "steps should be in ({}, 1)".format(steps, warmup_iters, total_iters, warmup_iters / total_iters)
+            )
+        if list(steps) != sorted(steps):
+            raise ValueError("steps {} is not in ascending order.".format(steps))
+        warnings.warn("ignore anneal_point when using step anneal_method")
        anneal_start = steps[0] * total_iters
    else:
        if anneal_point > 1 or anneal_point < 0:
@@ -38,91 +52,108 @@ def flat_and_anneal_lr_scheduler(optimizer, total_iters, \

    def f(x):  # x is the iter in the lr scheduler; return the lr_factor
        # the lr is lr_factor * base_lr
+        x = x % total_iters  # cyclic
        if x < warmup_iters:
-            if warmup_method == 'linear':
+            if warmup_method == "linear":
                alpha = float(x) / warmup_iters
                return warmup_factor * (1 - alpha) + alpha
-            elif warmup_method == 'constant':
+            elif warmup_method == "constant":
                return warmup_factor
        elif x >= anneal_start:
-            if anneal_method == 'step':
+            if anneal_method == "step":
                # ignore anneal_point and target_lr_factor
                milestones = [_step * total_iters for _step in steps]
-                lr_factor = step_gamma ** bisect_right(milestones, float(x))
-            elif anneal_method == 'cosine':
+                lr_factor = step_gamma ** bisect_right(milestones, float(x))
+            elif anneal_method == "cosine":
                # slow --> fast --> slow
-                lr_factor = target_lr_factor + 0.5 * (1 - target_lr_factor) * \
-                    (1 + cos(pi * ((float(x) - anneal_start) / (total_iters - anneal_start))))
-            elif anneal_method == 'linear':
+                lr_factor = target_lr_factor + 0.5 * (1 - target_lr_factor) * (
+                    1 + cos(pi * ((float(x) - anneal_start) / (total_iters - anneal_start)))
+                )
+            elif anneal_method == "linear":
                # (y-m) / (B-x) = (1-m) / (B-A)
-                lr_factor = target_lr_factor + (1 - target_lr_factor) * \
-                    (total_iters - float(x)) / (total_iters - anneal_start)
-            elif anneal_method == 'poly':
+                lr_factor = target_lr_factor + (1 - target_lr_factor) * (total_iters - float(x)) / (
+                    total_iters - anneal_start
+                )
+            elif anneal_method == "poly":
                # slow --> fast if poly_power < 1
                # fast --> slow if poly_power > 1
                # when poly_power == 1.0, it is the same as linear
-                lr_factor = target_lr_factor + (1 - target_lr_factor) * \
-                    ((total_iters - float(x)) / (total_iters - anneal_start)) ** poly_power
-            elif anneal_method == 'exp':
+                lr_factor = (
+                    target_lr_factor
+                    + (1 - target_lr_factor) * ((total_iters - float(x)) / (total_iters - anneal_start)) ** poly_power
+                )
+            elif anneal_method == "exp":
                # fast --> slow
                # do not decay too much; especially if lr_end == 0, lr will be
                # 0 at the anneal iter, so we should avoid that
                _target_lr_factor = max(target_lr_factor, 5e-3)
-                lr_factor = _target_lr_factor ** ( \
-                    (float(x) - anneal_start) / (total_iters - anneal_start))
+                lr_factor = _target_lr_factor ** ((float(x) - anneal_start) / (total_iters - anneal_start))
            else:
                lr_factor = 1
            return lr_factor
        else:  # warmup_iters <= x < anneal_start
            return 1

-    return torch.optim.lr_scheduler.LambdaLR(optimizer, f)
+    if return_function:
+        return torch.optim.lr_scheduler.LambdaLR(optimizer, f), f
+    else:
+        return torch.optim.lr_scheduler.LambdaLR(optimizer, f)
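# A minimal usage sketch (illustrative only; the toy model, optimizer and iteration counts below
# are assumptions, not part of the scheduler above). It shows the one-cycle setting from the
# docstring (total_iters == max_iter) together with return_function=True, which also hands back
# the raw lambda f, e.g. for inspecting lr factors or rebuilding the LambdaLR when resuming
# without a scheduler state_dict.
def _example_usage(max_iter=10000):
    import torch.nn as nn

    model = nn.Linear(4, 2)  # stand-in model for the example
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scheduler, f = flat_and_anneal_lr_scheduler(
        optimizer,
        total_iters=max_iter,  # one cycle over the whole run
        warmup_iters=500,
        warmup_factor=0.1,
        anneal_method="cosine",
        anneal_point=0.72,  # flat for the first 72% of the cycle, then cosine anneal
        return_function=True,
    )
    # f maps an iteration index to the lr multiplier:
    # f(0) == warmup_factor, f(max_iter // 2) == 1 (flat phase), f(max_iter - 1) ~ target_lr_factor
    print(f(0), f(max_iter // 2), f(max_iter - 1))
    return scheduler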


def test_flat_and_anneal():
    from mmcv import Config
    import numpy as np
-    model = resnet18()

-    optimizer_cfg = dict(type='Adam', lr=1e-4, weight_decay=0)
+    model = resnet18()
+    base_lr = 1e-4
+    optimizer_cfg = dict(type="Adam", lr=base_lr, weight_decay=0)
    optimizer = obj_from_dict(optimizer_cfg, torch.optim, dict(params=model.parameters()))

    # learning policy
    total_epochs = 80
    epoch_len = 500
-    total_iters = epoch_len * total_epochs
+    total_iters = epoch_len * total_epochs // 2
    # poly, step, linear, exp, cosine
    lr_cfg = Config(
        dict(
-            anneal_method='cosine',
-            warmup_method='linear',
+            # anneal_method="cosine",
+            # anneal_method="linear",
+            # anneal_method="poly",
+            # anneal_method="exp",
+            anneal_method="step",
+            warmup_method="linear",
            step_gamma=0.1,
            warmup_factor=0.1,
            warmup_iters=800,
            poly_power=5,
-            target_lr_factor=0.,
+            target_lr_factor=0.0,
            steps=[0.5, 0.75, 0.9],
            anneal_point=0.72,
-        ))
+        )
+    )

    # scheduler = build_scheduler(lr_config, optimizer, epoch_length)
    scheduler = flat_and_anneal_lr_scheduler(
-        optimizer=optimizer, total_iters=total_iters, \
-        warmup_method=lr_cfg.warmup_method, warmup_factor=lr_cfg.warmup_factor, \
-        warmup_iters=lr_cfg.warmup_iters, \
-        anneal_method=lr_cfg.anneal_method, anneal_point=lr_cfg.anneal_point, \
-        target_lr_factor=lr_cfg.target_lr_factor, \
-        poly_power=lr_cfg.poly_power, \
-        step_gamma=lr_cfg.step_gamma, steps=lr_cfg.steps, \
+        optimizer=optimizer,
+        total_iters=total_iters,
+        warmup_method=lr_cfg.warmup_method,
+        warmup_factor=lr_cfg.warmup_factor,
+        warmup_iters=lr_cfg.warmup_iters,
+        anneal_method=lr_cfg.anneal_method,
+        anneal_point=lr_cfg.anneal_point,
+        target_lr_factor=lr_cfg.target_lr_factor,
+        poly_power=lr_cfg.poly_power,
+        step_gamma=lr_cfg.step_gamma,
+        steps=lr_cfg.steps,
    )
-    print('start lr: {}'.format(scheduler.get_lr()))
+    print("start lr: {}".format(scheduler.get_lr()))
    steps = []
    lrs = []

    epoch_lrs = []
    global_step = 0

-    start_epoch = 20
+    start_epoch = 0
    for epoch in range(start_epoch):
        for batch in range(epoch_len):
            scheduler.step()  # when no state_dict available
@@ -133,41 +164,38 @@ def test_flat_and_anneal():
        # scheduler.step(epoch)
        # print(type(scheduler.get_lr()[0]))
        # import pdb;pdb.set_trace()
-        epoch_lrs.append([epoch,
-                          scheduler.get_lr()[0]])  # only get the first lr (maybe a group of lrs)
+        epoch_lrs.append([epoch, scheduler.get_lr()[0]])  # only get the first lr (maybe a group of lrs)
        for batch in range(epoch_len):
            # if global_step < lr_config['warmup_iters']:
            # scheduler.step(global_step)
            cur_lr = scheduler.get_lr()[0]
            if global_step == 0 or (len(lrs) >= 1 and cur_lr != lrs[-1]):
-                print('epoch {}, batch: {}, global_step:{} lr: {}'.format(
-                    epoch, batch, global_step, cur_lr))
+                print("epoch {}, batch: {}, global_step: {} lr: {}".format(epoch, batch, global_step, cur_lr))
                steps.append(global_step)
                lrs.append(cur_lr)
            global_step += 1
            scheduler.step()  # usually after optimizer.step()
    # print(epoch_lrs)
    # import pdb;pdb.set_trace()
-    epoch_lrs.append([total_epochs, scheduler.get_lr()[0]])
+    # epoch_lrs.append([total_epochs, scheduler.get_lr()[0]])

    epoch_lrs = np.asarray(epoch_lrs, dtype=np.float32)
    for i in range(len(epoch_lrs)):
-        print('{:02d} {}'.format(int(epoch_lrs[i][0]), epoch_lrs[i][1]))
+        print("{:02d} {}".format(int(epoch_lrs[i][0]), epoch_lrs[i][1]))

-    plt.figure(dpi=200)
-    plt.suptitle('{}'.format(dict(lr_cfg)), size=4)
+    plt.figure(dpi=100)
+    plt.suptitle("{}".format(dict(lr_cfg)), size=4)
    plt.subplot(1, 2, 1)
-    plt.plot(steps, lrs)
+    plt.plot(steps, lrs, "-.")
    # plt.show()
    plt.subplot(1, 2, 2)
    # print(epoch_lrs.dtype)
-    plt.plot(epoch_lrs[:, 0], epoch_lrs[:, 1])
+    plt.plot(epoch_lrs[:, 0], epoch_lrs[:, 1], "-.")
    plt.show()


if __name__ == "__main__":
    from mmcv.runner import obj_from_dict
-    import sys
    import os.path as osp
    from torchvision.models import resnet18
    import matplotlib.pyplot as plt