Commit 475460c

Revert "Fix broadcastable -> shape deprecations"
This reverts commit 7af102d.
1 parent 18629e1 commit 475460c
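
The reverted change had swapped Aesara's deprecated `broadcastable` keyword for its replacement, `shape`. The two arguments take different values: `broadcastable` is a pattern of booleans, while `shape` expects per-dimension lengths, where 1 marks a broadcastable dimension and None an unknown one. A boolean pattern passed to `shape` therefore reads as lengths (True == 1, False == 0), which is presumably why the change is being reverted here. A minimal sketch of the two spellings for a row-shaped type, assuming a recent Aesara (the exact deprecation boundary is version-dependent):

import aesara.tensor as at

# Pre-deprecation spelling, restored by this revert: booleans mark
# broadcastable dimensions.
row_old = at.TensorType("float64", broadcastable=(True, False))

# Post-deprecation spelling: static lengths, 1 == broadcastable,
# None == unknown length.
row_new = at.TensorType("float64", shape=(1, None))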

File tree
7 files changed: +33 additions, -19 deletions

pymc/aesaraf.py
Lines changed: 1 addition & 1 deletion

@@ -523,7 +523,7 @@ def make_shared_replacements(point, vars, model):
     """
     othervars = set(model.value_vars) - set(vars)
     return {
-        var: aesara.shared(point[var.name], var.name + "_shared", shape=var.broadcastable)
+        var: aesara.shared(point[var.name], var.name + "_shared", broadcastable=var.broadcastable)
         for var in othervars
     }
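
A sketch of the call this hunk restores, with a hypothetical point entry whose value variable is row-shaped (broadcast pattern (True, False)):

import numpy as np
import aesara

# Hypothetical inputs: one entry of `point` and its variable's pattern.
point = {"mu": np.zeros((1, 3))}
mu_shared = aesara.shared(point["mu"], "mu_shared", broadcastable=(True, False))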

pymc/distributions/continuous.py
Lines changed: 3 additions & 3 deletions

@@ -3994,9 +3994,9 @@ def make_node(self, x, h, z):
         x = at.as_tensor_variable(floatX(x))
         h = at.as_tensor_variable(floatX(h))
         z = at.as_tensor_variable(floatX(z))
-        bshape = broadcast_shape(x, h, z)
-        shape = [False] * len(bshape)
-        return Apply(self, [x, h, z], [at.TensorType(aesara.config.floatX, shape=shape)()])
+        shape = broadcast_shape(x, h, z)
+        broadcastable = [] if not shape else [False] * len(shape)
+        return Apply(self, [x, h, z], [at.TensorType(aesara.config.floatX, broadcastable)()])

     def perform(self, node, ins, outs):
         x, h, z = ins[0], ins[1], ins[2]
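
The restored lines build an all-non-broadcastable output pattern from the symbolic broadcast of the inputs. A sketch of the same idiom outside the Op, assuming broadcast_shape comes from aesara.tensor.extra_ops as elsewhere in Aesara:

import aesara
import aesara.tensor as at
from aesara.tensor.extra_ops import broadcast_shape

x = at.vector("x")
h = at.matrix("h")
shape = broadcast_shape(x, h)  # tuple of symbolic per-dimension lengths
broadcastable = [] if not shape else [False] * len(shape)
out = at.TensorType(aesara.config.floatX, broadcastable)()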

pymc/distributions/multivariate.py
Lines changed: 1 addition & 1 deletion

@@ -851,7 +851,7 @@ class PosDefMatrix(Op):
     def make_node(self, x):
         x = at.as_tensor_variable(x)
         assert x.ndim == 2
-        o = TensorType(dtype="int8", shape=[])()
+        o = TensorType(dtype="int8", broadcastable=[])()
         return Apply(self, [x], [o])

     # Python implementation:
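
An empty broadcastable pattern declares a 0-d tensor, which is how this Op allocates its scalar int8 result; a minimal sketch:

from aesara.tensor import TensorType

# 0-d int8 output: no dimensions, so an empty broadcast pattern.
o = TensorType(dtype="int8", broadcastable=[])()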

pymc/model.py
Lines changed: 1 addition & 1 deletion

@@ -364,7 +364,7 @@ def __init__(
         self._extra_vars_shared = {}
         for var, value in extra_vars_and_values.items():
             shared = aesara.shared(
-                value, var.name + "_shared__", shape=[s == 1 for s in value.shape]
+                value, var.name + "_shared__", broadcastable=[s == 1 for s in value.shape]
             )
             self._extra_vars_shared[var.name] = shared
             givens.append((var, shared))
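
Here the pattern is derived from the initial value itself: every axis of length 1 is declared broadcastable. A sketch with an illustrative (1, 5) value and a hypothetical name:

import numpy as np
import aesara

value = np.zeros((1, 5))
# [s == 1 for s in value.shape] == [True, False]
shared = aesara.shared(
    value, "beta_shared__", broadcastable=[s == 1 for s in value.shape]
)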

pymc/smc/smc.py
Lines changed: 1 addition & 1 deletion

@@ -565,7 +565,7 @@ def _logp_forw(point, out_vars, in_vars, shared):
     new_in_vars = []
     for in_var in in_vars:
         if in_var.dtype in discrete_types:
-            float_var = at.TensorType("floatX", in_var.shape)(in_var.name)
+            float_var = at.TensorType("floatX", in_var.broadcastable)(in_var.name)
             new_in_vars.append(float_var)
             replace_int_input[in_var] = at.round(float_var).astype(in_var.dtype)
         else:
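
_logp_forw swaps each discrete input for a float-typed clone with the same broadcast pattern, then rounds the clone back to the original dtype. A sketch with a hypothetical int64 vector:

import aesara.tensor as at

in_var = at.lvector("k")  # hypothetical discrete (int64) input
float_var = at.TensorType("floatX", in_var.broadcastable)(in_var.name)
replacement = at.round(float_var).astype(in_var.dtype)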

pymc/tests/test_sampling.py
Lines changed: 2 additions & 2 deletions

@@ -536,7 +536,7 @@ def test_choose_chains(n_points, tune, expected_length, expected_n_traces):
 @pytest.mark.xfail(condition=(aesara.config.floatX == "float32"), reason="Fails on float32")
 class TestNamedSampling(SeededTest):
     def test_shared_named(self):
-        G_var = shared(value=np.atleast_2d(1.0), shape=(True, False), name="G")
+        G_var = shared(value=np.atleast_2d(1.0), broadcastable=(True, False), name="G")

         with pm.Model():
             theta0 = pm.Normal(
@@ -553,7 +553,7 @@ def test_shared_named(self):
         assert np.isclose(res, 0.0)

     def test_shared_unnamed(self):
-        G_var = shared(value=np.atleast_2d(1.0), shape=(True, False))
+        G_var = shared(value=np.atleast_2d(1.0), broadcastable=(True, False))
         with pm.Model():
             theta0 = pm.Normal(
                 "theta0",

pymc/variational/updates.py
Lines changed: 24 additions & 10 deletions

@@ -276,7 +276,7 @@ def apply_momentum(updates, params=None, momentum=0.9):
     for param in params:
         value = param.get_value(borrow=True)
         velocity = aesara.shared(
-            np.zeros(value.shape, dtype=value.dtype), shape=param.broadcastable
+            np.zeros(value.shape, dtype=value.dtype), broadcastable=param.broadcastable
         )
         x = momentum * velocity + updates[param]
         updates[velocity] = x - param
@@ -391,7 +391,7 @@ def apply_nesterov_momentum(updates, params=None, momentum=0.9):
     for param in params:
         value = param.get_value(borrow=True)
         velocity = aesara.shared(
-            np.zeros(value.shape, dtype=value.dtype), shape=param.broadcastable
+            np.zeros(value.shape, dtype=value.dtype), broadcastable=param.broadcastable
         )
         x = momentum * velocity + updates[param] - param
         updates[velocity] = x
@@ -534,7 +534,9 @@ def adagrad(loss_or_grads=None, params=None, learning_rate=1.0, epsilon=1e-6):

     for param, grad in zip(params, grads):
         value = param.get_value(borrow=True)
-        accu = aesara.shared(np.zeros(value.shape, dtype=value.dtype), shape=param.broadcastable)
+        accu = aesara.shared(
+            np.zeros(value.shape, dtype=value.dtype), broadcastable=param.broadcastable
+        )
         accu_new = accu + grad**2
         updates[accu] = accu_new
         updates[param] = param - (learning_rate * grad / at.sqrt(accu_new + epsilon))
@@ -660,7 +662,9 @@ def rmsprop(loss_or_grads=None, params=None, learning_rate=1.0, rho=0.9, epsilon=1e-6):

     for param, grad in zip(params, grads):
         value = param.get_value(borrow=True)
-        accu = aesara.shared(np.zeros(value.shape, dtype=value.dtype), shape=param.broadcastable)
+        accu = aesara.shared(
+            np.zeros(value.shape, dtype=value.dtype), broadcastable=param.broadcastable
+        )
         accu_new = rho * accu + (one - rho) * grad**2
         updates[accu] = accu_new
         updates[param] = param - (learning_rate * grad / at.sqrt(accu_new + epsilon))
@@ -751,10 +755,12 @@ def adadelta(loss_or_grads=None, params=None, learning_rate=1.0, rho=0.95, epsilon=1e-6):
     for param, grad in zip(params, grads):
         value = param.get_value(borrow=True)
         # accu: accumulate gradient magnitudes
-        accu = aesara.shared(np.zeros(value.shape, dtype=value.dtype), shape=param.broadcastable)
+        accu = aesara.shared(
+            np.zeros(value.shape, dtype=value.dtype), broadcastable=param.broadcastable
+        )
         # delta_accu: accumulate update magnitudes (recursively!)
         delta_accu = aesara.shared(
-            np.zeros(value.shape, dtype=value.dtype), shape=param.broadcastable
+            np.zeros(value.shape, dtype=value.dtype), broadcastable=param.broadcastable
         )

         # update accu (as in rmsprop)
@@ -844,8 +850,12 @@ def adam(

     for param, g_t in zip(params, all_grads):
         value = param.get_value(borrow=True)
-        m_prev = aesara.shared(np.zeros(value.shape, dtype=value.dtype), shape=param.broadcastable)
-        v_prev = aesara.shared(np.zeros(value.shape, dtype=value.dtype), shape=param.broadcastable)
+        m_prev = aesara.shared(
+            np.zeros(value.shape, dtype=value.dtype), broadcastable=param.broadcastable
+        )
+        v_prev = aesara.shared(
+            np.zeros(value.shape, dtype=value.dtype), broadcastable=param.broadcastable
+        )

         m_t = beta1 * m_prev + (one - beta1) * g_t
         v_t = beta2 * v_prev + (one - beta2) * g_t**2
@@ -928,8 +938,12 @@ def adamax(

     for param, g_t in zip(params, all_grads):
         value = param.get_value(borrow=True)
-        m_prev = aesara.shared(np.zeros(value.shape, dtype=value.dtype), shape=param.broadcastable)
-        u_prev = aesara.shared(np.zeros(value.shape, dtype=value.dtype), shape=param.broadcastable)
+        m_prev = aesara.shared(
+            np.zeros(value.shape, dtype=value.dtype), broadcastable=param.broadcastable
+        )
+        u_prev = aesara.shared(
+            np.zeros(value.shape, dtype=value.dtype), broadcastable=param.broadcastable
+        )

         m_t = beta1 * m_prev + (one - beta1) * g_t
         u_t = at.maximum(beta2 * u_prev, abs(g_t))
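
All of these optimizers rely on the same accumulator idiom that the revert restores: a zero-initialized shared variable matching the parameter's dtype, shape, and broadcast pattern. A sketch with an illustrative parameter:

import numpy as np
import aesara

param = aesara.shared(np.zeros((1, 3)), "w", broadcastable=(True, False))
value = param.get_value(borrow=True)
velocity = aesara.shared(
    np.zeros(value.shape, dtype=value.dtype), broadcastable=param.broadcastable
)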
