Skip to content

Commit 93ad020

Browse files
committed
Propagate static shape of (some) sparse operations
1 parent 4702855 commit 93ad020

File tree

1 file changed

+19
-7
lines changed

1 file changed

+19
-7
lines changed

pytensor/sparse/basic.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -500,19 +500,19 @@ def unique_value(self):
500500

501501

502502
# for more dtypes, call SparseTensorType(format, dtype)
503-
def matrix(format, name=None, dtype=None):
503+
def matrix(format, name=None, dtype=None, shape=None):
504504
if dtype is None:
505505
dtype = config.floatX
506-
type = SparseTensorType(format=format, dtype=dtype)
506+
type = SparseTensorType(format=format, dtype=dtype, shape=shape)
507507
return type(name)
508508

509509

510-
def csc_matrix(name=None, dtype=None):
511-
return matrix("csc", name, dtype)
510+
def csc_matrix(name=None, dtype=None, shape=None):
511+
return matrix("csc", name=name, dtype=dtype, shape=shape)
512512

513513

514-
def csr_matrix(name=None, dtype=None):
515-
return matrix("csr", name, dtype)
514+
def csr_matrix(name=None, dtype=None, shape=None):
515+
return matrix("csr", name=name, dtype=dtype, shape=shape)
516516

517517

518518
def bsr_matrix(name=None, dtype=None):
@@ -727,10 +727,22 @@ def make_node(self, data, indices, indptr, shape):
727727
if shape.type.ndim != 1 or shape.type.dtype not in discrete_dtypes:
728728
raise TypeError("n_rows must be integer type", shape, shape.type)
729729

730+
static_shape = (None, None)
731+
if (
732+
shape.owner is not None
733+
and isinstance(shape.owner.op, CSMProperties)
734+
and shape.owner.outputs[3] is shape
735+
):
736+
static_shape = shape.owner.inputs[0].type.shape
737+
730738
return Apply(
731739
self,
732740
[data, indices, indptr, shape],
733-
[SparseTensorType(dtype=data.type.dtype, format=self.format)()],
741+
[
742+
SparseTensorType(
743+
dtype=data.type.dtype, format=self.format, shape=static_shape
744+
)()
745+
],
734746
)
735747

736748
def perform(self, node, inputs, outputs):

0 commit comments

Comments
 (0)