@@ -275,9 +275,7 @@ def apply_momentum(updates, params=None, momentum=0.9):
275275
276276 for param in params :
277277 value = param .get_value (borrow = True )
278- velocity = aesara .shared (
279- np .zeros (value .shape , dtype = value .dtype ), broadcastable = param .broadcastable
280- )
278+ velocity = aesara .shared (np .zeros (value .shape , dtype = value .dtype ), shape = param .type .shape )
281279 x = momentum * velocity + updates [param ]
282280 updates [velocity ] = x - param
283281 updates [param ] = x
@@ -390,9 +388,7 @@ def apply_nesterov_momentum(updates, params=None, momentum=0.9):
390388
391389 for param in params :
392390 value = param .get_value (borrow = True )
393- velocity = aesara .shared (
394- np .zeros (value .shape , dtype = value .dtype ), broadcastable = param .broadcastable
395- )
391+ velocity = aesara .shared (np .zeros (value .shape , dtype = value .dtype ), shape = param .type .shape )
396392 x = momentum * velocity + updates [param ] - param
397393 updates [velocity ] = x
398394 updates [param ] = momentum * x + updates [param ]
@@ -534,9 +530,7 @@ def adagrad(loss_or_grads=None, params=None, learning_rate=1.0, epsilon=1e-6):
534530
535531 for param , grad in zip (params , grads ):
536532 value = param .get_value (borrow = True )
537- accu = aesara .shared (
538- np .zeros (value .shape , dtype = value .dtype ), broadcastable = param .broadcastable
539- )
533+ accu = aesara .shared (np .zeros (value .shape , dtype = value .dtype ), shape = param .type .shape )
540534 accu_new = accu + grad ** 2
541535 updates [accu ] = accu_new
542536 updates [param ] = param - (learning_rate * grad / at .sqrt (accu_new + epsilon ))
@@ -662,9 +656,7 @@ def rmsprop(loss_or_grads=None, params=None, learning_rate=1.0, rho=0.9, epsilon
662656
663657 for param , grad in zip (params , grads ):
664658 value = param .get_value (borrow = True )
665- accu = aesara .shared (
666- np .zeros (value .shape , dtype = value .dtype ), broadcastable = param .broadcastable
667- )
659+ accu = aesara .shared (np .zeros (value .shape , dtype = value .dtype ), shape = param .type .shape )
668660 accu_new = rho * accu + (one - rho ) * grad ** 2
669661 updates [accu ] = accu_new
670662 updates [param ] = param - (learning_rate * grad / at .sqrt (accu_new + epsilon ))
@@ -755,13 +747,9 @@ def adadelta(loss_or_grads=None, params=None, learning_rate=1.0, rho=0.95, epsil
755747 for param , grad in zip (params , grads ):
756748 value = param .get_value (borrow = True )
757749 # accu: accumulate gradient magnitudes
758- accu = aesara .shared (
759- np .zeros (value .shape , dtype = value .dtype ), broadcastable = param .broadcastable
760- )
750+ accu = aesara .shared (np .zeros (value .shape , dtype = value .dtype ), shape = param .type .shape )
761751 # delta_accu: accumulate update magnitudes (recursively!)
762- delta_accu = aesara .shared (
763- np .zeros (value .shape , dtype = value .dtype ), broadcastable = param .broadcastable
764- )
752+ delta_accu = aesara .shared (np .zeros (value .shape , dtype = value .dtype ), shape = param .type .shape )
765753
766754 # update accu (as in rmsprop)
767755 accu_new = rho * accu + (one - rho ) * grad ** 2
@@ -850,12 +838,8 @@ def adam(
850838
851839 for param , g_t in zip (params , all_grads ):
852840 value = param .get_value (borrow = True )
853- m_prev = aesara .shared (
854- np .zeros (value .shape , dtype = value .dtype ), broadcastable = param .broadcastable
855- )
856- v_prev = aesara .shared (
857- np .zeros (value .shape , dtype = value .dtype ), broadcastable = param .broadcastable
858- )
841+ m_prev = aesara .shared (np .zeros (value .shape , dtype = value .dtype ), shape = param .type .shape )
842+ v_prev = aesara .shared (np .zeros (value .shape , dtype = value .dtype ), shape = param .type .shape )
859843
860844 m_t = beta1 * m_prev + (one - beta1 ) * g_t
861845 v_t = beta2 * v_prev + (one - beta2 ) * g_t ** 2
@@ -938,12 +922,8 @@ def adamax(
938922
939923 for param , g_t in zip (params , all_grads ):
940924 value = param .get_value (borrow = True )
941- m_prev = aesara .shared (
942- np .zeros (value .shape , dtype = value .dtype ), broadcastable = param .broadcastable
943- )
944- u_prev = aesara .shared (
945- np .zeros (value .shape , dtype = value .dtype ), broadcastable = param .broadcastable
946- )
925+ m_prev = aesara .shared (np .zeros (value .shape , dtype = value .dtype ), shape = param .type .shape )
926+ u_prev = aesara .shared (np .zeros (value .shape , dtype = value .dtype ), shape = param .type .shape )
947927
948928 m_t = beta1 * m_prev + (one - beta1 ) * g_t
949929 u_t = at .maximum (beta2 * u_prev , abs (g_t ))
0 commit comments