Skip to content

Commit 4702855

Browse files
committed
Allow single integer as TensorType shape
1 parent 792bd04 commit 4702855

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

pytensor/tensor/type.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape):
7171
def __init__(
7272
self,
7373
dtype: str | npt.DTypeLike,
74-
shape: Iterable[bool | int | None] | None = None,
74+
shape: Iterable[bool | int | None] | int | None = None,
7575
name: str | None = None,
7676
broadcastable: Iterable[bool] | None = None,
7777
):
@@ -99,7 +99,7 @@ def __init__(
9999
)
100100
shape = broadcastable
101101

102-
if str(dtype) == "floatX":
102+
if dtype == "floatX":
103103
self.dtype = config.floatX
104104
else:
105105
try:
@@ -118,6 +118,8 @@ def parse_bcast_and_shape(s):
118118
f"TensorType broadcastable/shape must be a boolean, integer or None, got {type(s)} {s}"
119119
)
120120

121+
if isinstance(shape, int):
122+
shape = (shape,)
121123
self.shape = tuple(parse_bcast_and_shape(s) for s in shape)
122124
self.dtype_specs() # error checking is done there
123125
self.name = name

0 commit comments

Comments
 (0)