 import logging
 import numpy as np
 from collections import OrderedDict
+try:
+    from apex import amp
+    has_apex = True
+except ImportError:
+    amp = None
+    has_apex = False
 
 from torch import distributed as dist
 
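The try/except import keeps the module usable when NVIDIA apex is not installed; callers can gate their own AMP setup on the resulting has_apex flag. A rough sketch of the training-side setup this diff assumes (the args.amp flag and opt_level choice here are illustrative, not part of this change):

    use_amp = False
    if has_apex and args.amp:
        # apex's standard entry point; 'O1' enables mixed-precision patching
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        use_amp = True
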
@@ -50,7 +56,7 @@ def __init__(
         self.max_history = max_history
         assert self.max_history >= 1
 
-    def save_checkpoint(self, model, optimizer, args, epoch, model_ema=None, metric=None):
+    def save_checkpoint(self, model, optimizer, args, epoch, model_ema=None, metric=None, use_amp=False):
         assert epoch >= 0
         worst_file = self.checkpoint_files[-1] if self.checkpoint_files else None
         if (len(self.checkpoint_files) < self.max_history
@@ -59,7 +65,7 @@ def save_checkpoint(self, model, optimizer, args, epoch, model_ema=None, metric=
                 self._cleanup_checkpoints(1)
             filename = '-'.join([self.save_prefix, str(epoch)]) + self.extension
             save_path = os.path.join(self.checkpoint_dir, filename)
-            self._save(save_path, model, optimizer, args, epoch, model_ema, metric)
+            self._save(save_path, model, optimizer, args, epoch, model_ema, metric, use_amp)
             self.checkpoint_files.append((save_path, metric))
             self.checkpoint_files = sorted(
                 self.checkpoint_files, key=lambda x: x[1],
@@ -77,7 +83,7 @@ def save_checkpoint(self, model, optimizer, args, epoch, model_ema=None, metric=
 
         return (None, None) if self.best_metric is None else (self.best_metric, self.best_epoch)
 
-    def _save(self, save_path, model, optimizer, args, epoch, model_ema=None, metric=None):
+    def _save(self, save_path, model, optimizer, args, epoch, model_ema=None, metric=None, use_amp=False):
         save_state = {
             'epoch': epoch,
             'arch': args.model,
@@ -86,6 +92,8 @@ def _save(self, save_path, model, optimizer, args, epoch, model_ema=None, metric
             'args': args,
             'version': 2,  # version < 2 increments epoch before save
         }
+        if use_amp and 'state_dict' in amp.__dict__:
+            save_state['amp'] = amp.state_dict()
         if model_ema is not None:
             save_state['state_dict_ema'] = get_state_dict(model_ema)
         if metric is not None:
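The 'state_dict' in amp.__dict__ check guards against older apex builds that predate checkpointable AMP state, so the save path stays safe on either version. On resume, the stored loss-scaler state would be restored with the matching load call after amp.initialize has run; a hedged sketch (the checkpoint path and resume flow are assumptions, not shown in this diff):

    import torch

    checkpoint = torch.load('output/train/checkpoint-42.pth.tar', map_location='cpu')
    if has_apex and use_amp and 'amp' in checkpoint:
        # restores the dynamic loss scaler saved by _save() above
        amp.load_state_dict(checkpoint['amp'])
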
@@ -106,11 +114,11 @@ def _cleanup_checkpoints(self, trim=0):
                 logging.error("Exception '{}' while deleting checkpoint".format(e))
         self.checkpoint_files = self.checkpoint_files[:delete_index]
 
-    def save_recovery(self, model, optimizer, args, epoch, model_ema=None, batch_idx=0):
+    def save_recovery(self, model, optimizer, args, epoch, model_ema=None, use_amp=False, batch_idx=0):
         assert epoch >= 0
         filename = '-'.join([self.recovery_prefix, str(epoch), str(batch_idx)]) + self.extension
         save_path = os.path.join(self.recovery_dir, filename)
-        self._save(save_path, model, optimizer, args, epoch, model_ema)
+        self._save(save_path, model, optimizer, args, epoch, model_ema, use_amp=use_amp)
         if os.path.exists(self.last_recovery_file):
             try:
                 logging.debug("Cleaning recovery: {}".format(self.last_recovery_file))
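With use_amp threaded through both save paths, a training loop passes its AMP flag at save time. Roughly, with an illustrative saver instance and metric name (not part of this diff):

    # end of epoch: save a ranked checkpoint, including amp state
    best_metric, best_epoch = saver.save_checkpoint(
        model, optimizer, args, epoch, model_ema=model_ema,
        metric=eval_metrics['top1'], use_amp=use_amp)

    # periodic in-epoch recovery snapshot takes the same flag
    saver.save_recovery(
        model, optimizer, args, epoch, model_ema=model_ema,
        use_amp=use_amp, batch_idx=batch_idx)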