File tree Expand file tree Collapse file tree 1 file changed +8
-2
lines changed Expand file tree Collapse file tree 1 file changed +8
-2
lines changed Original file line number Diff line number Diff line change @@ -514,8 +514,14 @@ def main():
514514 if utils .is_primary (args ):
515515 _logger .info ('Using NVIDIA APEX AMP. Training in mixed precision.' )
516516 elif use_amp == 'native' :
517- amp_autocast = partial (torch .autocast , device_type = device .type , dtype = amp_dtype )
518- if device .type == 'cuda' :
517+ try :
518+ amp_autocast = partial (torch .autocast , device_type = device .type , dtype = amp_dtype )
519+ except (AttributeError , TypeError ):
520+ # fallback to CUDA only AMP for PyTorch < 1.10
521+ assert device .type == 'cuda'
522+ amp_autocast = torch .cuda .amp .autocast
523+ if device .type == 'cuda' and amp_dtype == torch .float16 :
524+ # loss scaler only used for float16 (half) dtype, bfloat16 does not need it
519525 loss_scaler = NativeScaler ()
520526 if utils .is_primary (args ):
521527 _logger .info ('Using native Torch AMP. Training in mixed precision.' )
You can’t perform that action at this time.
0 commit comments