@@ -37,20 +37,67 @@ def unwrap_model(model):
     return model.module if hasattr(model, 'module') else model


-def get_state_dict(model):
-    return unwrap_model(model).state_dict()
+def get_state_dict(model, unwrap_fn=unwrap_model):
+    return unwrap_fn(model).state_dict()
+
+
+class ApexScaler:
+    state_dict_key = "amp"
+
+    def __call__(self, loss, optimizer):
+        with amp.scale_loss(loss, optimizer) as scaled_loss:
+            scaled_loss.backward()
+        optimizer.step()
+
+    def state_dict(self):
+        if 'state_dict' in amp.__dict__:
+            return amp.state_dict()
+
+    def load_state_dict(self, state_dict):
+        if 'load_state_dict' in amp.__dict__:
+            amp.load_state_dict(state_dict)
+
+
+class NativeScaler:
+    state_dict_key = "amp_scaler"
+
+    def __init__(self):
+        self._scaler = torch.cuda.amp.GradScaler()
+
+    def __call__(self, loss, optimizer):
+        self._scaler.scale(loss).backward()
+        self._scaler.step(optimizer)
+        self._scaler.update()
+
+    def state_dict(self):
+        return self._scaler.state_dict()
+
+    def load_state_dict(self, state_dict):
+        self._scaler.load_state_dict(state_dict)


 class CheckpointSaver:
     def __init__(
             self,
+            model,
+            optimizer,
+            args=None,
+            model_ema=None,
+            amp_scaler=None,
             checkpoint_prefix='checkpoint',
             recovery_prefix='recovery',
             checkpoint_dir='',
             recovery_dir='',
             decreasing=False,
             max_history=10,
-            save_amp=False):
+            unwrap_fn=unwrap_model):
+
+        # objects to save state_dicts of
+        self.model = model
+        self.optimizer = optimizer
+        self.args = args
+        self.model_ema = model_ema
+        self.amp_scaler = amp_scaler

         # state
         self.checkpoint_files = []  # (filename, metric) tuples in order of decreasing betterness
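
The two scaler wrappers above share a call signature, so training code can drive APEX and native AMP uniformly. A minimal sketch of how a training step might use NativeScaler; everything except the scaler itself (model, optimizer, inputs, targets) is illustrative scaffolding, not part of this change:

import torch
import torch.nn as nn

model = nn.Linear(10, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_scaler = NativeScaler()

inputs = torch.randn(8, 10, device='cuda')
targets = torch.randint(0, 2, (8,), device='cuda')

optimizer.zero_grad()
with torch.cuda.amp.autocast():
    loss = nn.functional.cross_entropy(model(inputs), targets)
# scale, backward, optimizer step, and scale update all happen inside __call__
loss_scaler(loss, optimizer)
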
@@ -68,14 +115,14 @@ def __init__(
         self.decreasing = decreasing  # a lower metric is better if True
         self.cmp = operator.lt if decreasing else operator.gt  # True if lhs better than rhs
         self.max_history = max_history
-        self.save_apex_amp = save_amp  # save APEX amp state
+        self.unwrap_fn = unwrap_fn
         assert self.max_history >= 1

-    def save_checkpoint(self, model, optimizer, args, epoch, model_ema=None, metric=None):
+    def save_checkpoint(self, epoch, metric=None):
         assert epoch >= 0
         tmp_save_path = os.path.join(self.checkpoint_dir, 'tmp' + self.extension)
         last_save_path = os.path.join(self.checkpoint_dir, 'last' + self.extension)
-        self._save(tmp_save_path, model, optimizer, args, epoch, model_ema, metric)
+        self._save(tmp_save_path, epoch, metric)
         if os.path.exists(last_save_path):
             os.unlink(last_save_path)  # required for Windows support.
         os.rename(tmp_save_path, last_save_path)
@@ -107,19 +154,21 @@ def save_checkpoint(self, model, optimizer, args, epoch, model_ema=None, metric=

         return (None, None) if self.best_metric is None else (self.best_metric, self.best_epoch)

-    def _save(self, save_path, model, optimizer, args, epoch, model_ema=None, metric=None):
+    def _save(self, save_path, epoch, metric=None):
         save_state = {
             'epoch': epoch,
-            'arch': args.model,
-            'state_dict': get_state_dict(model),
-            'optimizer': optimizer.state_dict(),
-            'args': args,
+            'arch': type(self.model).__name__.lower(),
+            'state_dict': get_state_dict(self.model, self.unwrap_fn),
+            'optimizer': self.optimizer.state_dict(),
             'version': 2,  # version < 2 increments epoch before save
         }
-        if self.save_apex_amp and 'state_dict' in amp.__dict__:
-            save_state['amp'] = amp.state_dict()
-        if model_ema is not None:
-            save_state['state_dict_ema'] = get_state_dict(model_ema)
+        if self.args is not None:
+            save_state['arch'] = self.args.model
+            save_state['args'] = self.args
+        if self.amp_scaler is not None:
+            save_state[self.amp_scaler.state_dict_key] = self.amp_scaler.state_dict()
+        if self.model_ema is not None:
+            save_state['state_dict_ema'] = get_state_dict(self.model_ema, self.unwrap_fn)
         if metric is not None:
             save_state['metric'] = metric
         torch.save(save_state, save_path)
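
With the state-holding objects bound at construction, the per-epoch call shrinks to just an epoch and an optional metric. A hypothetical sketch of the updated CheckpointSaver usage; the training-script names args, model, optimizer, loss_scaler, epoch, and eval_metric are placeholders:

saver = CheckpointSaver(
    model=model,
    optimizer=optimizer,
    args=args,               # optional; when present, overrides 'arch' and adds 'args' to the checkpoint
    model_ema=None,
    amp_scaler=loss_scaler,  # ApexScaler/NativeScaler instance, or None to skip scaler state
    checkpoint_dir='./output',
    decreasing=False,        # a higher metric counts as better
    max_history=10)

best_metric, best_epoch = saver.save_checkpoint(epoch, metric=eval_metric)
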
@@ -138,11 +187,11 @@ def _cleanup_checkpoints(self, trim=0):
                 _logger.error("Exception '{}' while deleting checkpoint".format(e))
         self.checkpoint_files = self.checkpoint_files[:delete_index]

-    def save_recovery(self, model, optimizer, args, epoch, model_ema=None, batch_idx=0):
+    def save_recovery(self, epoch, batch_idx=0):
         assert epoch >= 0
         filename = '-'.join([self.recovery_prefix, str(epoch), str(batch_idx)]) + self.extension
         save_path = os.path.join(self.recovery_dir, filename)
-        self._save(save_path, model, optimizer, args, epoch, model_ema)
+        self._save(save_path, epoch)
         if os.path.exists(self.last_recovery_file):
             try:
                 _logger.debug("Cleaning recovery: {}".format(self.last_recovery_file))
@@ -336,3 +385,16 @@ def add_bool_arg(parser, name, default=False, help=''):
     group.add_argument('--' + name, dest=dest_name, action='store_true', help=help)
     group.add_argument('--no-' + name, dest=dest_name, action='store_false', help=help)
     parser.set_defaults(**{dest_name: default})
+
+
+def set_jit_legacy():
+    """ Set JIT executor to legacy w/ support for op fusion.
+    This is hopefully a temporary need in 1.5/1.5.1/1.6 to restore performance due to changes
+    in the JIT executor. These APIs are not supported so could change.
+    """
+    assert hasattr(torch._C, '_jit_set_profiling_executor'), "Old JIT behavior doesn't exist!"
+    torch._C._jit_set_profiling_executor(False)
+    torch._C._jit_set_profiling_mode(False)
+    torch._C._jit_override_can_fuse_on_gpu(True)
+    # torch._C._jit_set_texpr_fuser_enabled(True)
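
A hypothetical usage sketch: on the affected PyTorch releases, call set_jit_legacy() before scripting a model so the legacy, fusion-capable executor is used for benchmarking; the scripted module below is only a stand-in:

import torch
import torch.nn as nn

set_jit_legacy()
scripted = torch.jit.script(nn.Sequential(nn.Linear(10, 10), nn.ReLU()))
out = scripted(torch.randn(4, 10))
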