@@ -22,7 +22,7 @@
 from timm.layers import set_fast_norm
 from timm.models import create_model, is_model, list_models
 from timm.optim import create_optimizer_v2
-from timm.utils import setup_default_logging, set_jit_fuser, decay_batch_step, check_batch_size_retry
+from timm.utils import setup_default_logging, set_jit_fuser, decay_batch_step, check_batch_size_retry, ParseKwargs

 has_apex = False
 try:
@@ -108,12 +108,15 @@
                     help='Enable gradient checkpointing through model blocks/stages')
 parser.add_argument('--amp', action='store_true', default=False,
                     help='use PyTorch Native AMP for mixed precision training. Overrides --precision arg.')
+parser.add_argument('--amp-dtype', default='float16', type=str,
+                    help='lower precision AMP dtype (default: float16). Overrides --precision arg if args.amp True.')
 parser.add_argument('--precision', default='float32', type=str,
                     help='Numeric precision. One of (amp, float32, float16, bfloat16, tf32)')
 parser.add_argument('--fuser', default='', type=str,
                     help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
 parser.add_argument('--fast-norm', default=False, action='store_true',
                     help='enable experimental fast-norm')
+parser.add_argument('--model-kwargs', nargs='*', default={}, action=ParseKwargs)

 # codegen (model compilation) options
 scripting_group = parser.add_mutually_exclusive_group()
@@ -124,7 +127,6 @@
 scripting_group.add_argument('--aot-autograd', default=False, action='store_true',
                              help="Enable AOT Autograd optimization.")

-
 # train optimizer parameters
 parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
                     help='Optimizer (default: "sgd"')
@@ -168,19 +170,21 @@ def count_params(model: nn.Module):


 def resolve_precision(precision: str):
-    assert precision in ('amp', 'float16', 'bfloat16', 'float32')
-    use_amp = False
+    assert precision in ('amp', 'amp_bfloat16', 'float16', 'bfloat16', 'float32')
+    amp_dtype = None  # amp disabled
     model_dtype = torch.float32
     data_dtype = torch.float32
     if precision == 'amp':
-        use_amp = True
+        amp_dtype = torch.float16
+    elif precision == 'amp_bfloat16':
+        amp_dtype = torch.bfloat16
     elif precision == 'float16':
         model_dtype = torch.float16
         data_dtype = torch.float16
     elif precision == 'bfloat16':
         model_dtype = torch.bfloat16
         data_dtype = torch.bfloat16
-    return use_amp, model_dtype, data_dtype
+    return amp_dtype, model_dtype, data_dtype


 def profile_deepspeed(model, input_size=(3, 224, 224), batch_size=1, detailed=False):
@@ -228,9 +232,12 @@ def __init__(
         self.model_name = model_name
         self.detail = detail
         self.device = device
-        self.use_amp, self.model_dtype, self.data_dtype = resolve_precision(precision)
+        self.amp_dtype, self.model_dtype, self.data_dtype = resolve_precision(precision)
         self.channels_last = kwargs.pop('channels_last', False)
-        self.amp_autocast = partial(torch.cuda.amp.autocast, dtype=torch.float16) if self.use_amp else suppress
+        if self.amp_dtype is not None:
+            self.amp_autocast = partial(torch.cuda.amp.autocast, dtype=self.amp_dtype)
+        else:
+            self.amp_autocast = suppress

         if fuser:
             set_jit_fuser(fuser)
@@ -243,6 +250,7 @@ def __init__(
             drop_rate=kwargs.pop('drop', 0.),
             drop_path_rate=kwargs.pop('drop_path', None),
             drop_block_rate=kwargs.pop('drop_block', None),
+            **kwargs.pop('model_kwargs', {}),
         )
         self.model.to(
             device=self.device,
@@ -560,7 +568,7 @@ def _try_run(
 def benchmark(args):
     if args.amp:
         _logger.warning("Overriding precision to 'amp' since --amp flag set.")
-        args.precision = 'amp'
+        args.precision = 'amp' if args.amp_dtype == 'float16' else '_'.join(['amp', args.amp_dtype])
     _logger.info(f'Benchmarking in {args.precision} precision. '
                  f'{"NHWC" if args.channels_last else "NCHW"} layout. '
                  f'torchscript {"enabled" if args.torchscript else "disabled"}')
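For reference, a minimal sketch of what the new precision handling resolves to, assuming the resolve_precision() from the diff above is in scope (the import path and the benchmark.py script name here are assumptions for illustration, not part of the change):

    import torch
    from benchmark import resolve_precision  # assumed import; run alongside timm's benchmark.py

    # 'amp_bfloat16' (what `--amp --amp-dtype bfloat16` maps --precision to) keeps the
    # model and data in float32 and only sets the autocast dtype to bfloat16.
    amp_dtype, model_dtype, data_dtype = resolve_precision('amp_bfloat16')
    assert amp_dtype == torch.bfloat16
    assert model_dtype == torch.float32 and data_dtype == torch.float32

    # Plain 'float16' disables autocast and instead casts the model and inputs themselves.
    amp_dtype, model_dtype, data_dtype = resolve_precision('float16')
    assert amp_dtype is None
    assert model_dtype == torch.float16 and data_dtype == torch.float16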