@@ -61,8 +61,10 @@ def create_optimizer(args, model, filter_bias_and_bn=True):
6161 opt_split = opt_lower .split ('_' )
6262 opt_lower = opt_split [- 1 ]
6363 if opt_lower == 'sgd' or opt_lower == 'nesterov' :
64+ del opt_args ['eps' ]
6465 optimizer = optim .SGD (parameters , momentum = args .momentum , nesterov = True , ** opt_args )
6566 elif opt_lower == 'momentum' :
67+ del opt_args ['eps' ]
6668 optimizer = optim .SGD (parameters , momentum = args .momentum , nesterov = False , ** opt_args )
6769 elif opt_lower == 'adam' :
6870 optimizer = optim .Adam (parameters , ** opt_args )
@@ -93,8 +95,10 @@ def create_optimizer(args, model, filter_bias_and_bn=True):
9395 elif opt_lower == 'nvnovograd' :
9496 optimizer = NvNovoGrad (parameters , ** opt_args )
9597 elif opt_lower == 'fusedsgd' :
98+ del opt_args ['eps' ]
9699 optimizer = FusedSGD (parameters , momentum = args .momentum , nesterov = True , ** opt_args )
97100 elif opt_lower == 'fusedmomentum' :
101+ del opt_args ['eps' ]
98102 optimizer = FusedSGD (parameters , momentum = args .momentum , nesterov = False , ** opt_args )
99103 elif opt_lower == 'fusedadam' :
100104 optimizer = FusedAdam (parameters , adam_w_mode = False , ** opt_args )
0 commit comments