@@ -388,20 +388,16 @@ def wraps_optimizer(cls):
388388 HvdOptimizer
389389 '''
class HvdOptimizer(cls, optimizer.Optimizer):
    """Horovod-aware wrapper around the user-supplied optimizer class.

    Scales the base learning rate linearly by the Horovod world size
    (the standard linear-scaling rule for synchronous data-parallel
    training), then delegates to the wrapped optimizer's constructor.

    Gradient aggregation is NOT done here: the factory below wraps
    instances in ``horovod.tensorflow.DistributedOptimizer``, which
    performs the cross-worker allreduce.
    """

    def __init__(self, learning_rate=0.001, *args, **kwargs):
        # Declare ``learning_rate`` as an explicit positional parameter so
        # callers may pass it either positionally or by keyword; the old
        # kwargs-only form silently ignored a positional learning rate.
        # NOTE(review): assumes HvdContext.get().world_size is the number
        # of participating workers — confirm against HvdContext.
        learning_rate = learning_rate * HvdContext.get().world_size
        super(HvdOptimizer, self).__init__(learning_rate, *args, **kwargs)
400395 if isinstance (cls , HvdOptimizer ):
401396 return cls
402397 else :
403398 def horovod_optimizer (* args , ** kwargs ):
404- return HvdOptimizer (* args , ** kwargs )
399+ from horovod .tensorflow import DistributedOptimizer
400+ return DistributedOptimizer (HvdOptimizer (* args , ** kwargs ))
405401 return horovod_optimizer
406402
407403
@@ -478,16 +474,6 @@ def HorovodMonitoredTrainingSession(*args, **kwargs): # pylint: disable=invalid
478474 kwargs ['config' ] = wraps_session_config (kwargs .pop ('config' , None ))
479475 kwargs ['is_chief' ] = True
480476 args = list (args )
481- if args :
482- master = args [0 ]
483- if not master :
484- master = ''
485- args [0 ] = master
486- else :
487- master = kwargs .pop ('master' , None )
488- if not master :
489- master = ''
490- kwargs ['master' ] = master
491477
492478 prev_monitored_session = _monitored_session .MonitoredSession
493479 sess = fn (* args , ** kwargs )
@@ -1449,4 +1435,4 @@ def export(export_dir_base,
14491435 as_text = as_text ,
14501436 clear_devices = clear_devices ,
14511437 strip_default_attrs = strip_default_attrs ,
1452- modes = [mode ])
1438+ modes = [mode ])
0 commit comments