@@ -500,19 +500,19 @@ def unique_value(self):
500500
501501
502502# for more dtypes, call SparseTensorType(format, dtype)
503- def matrix (format , name = None , dtype = None ):
503+ def matrix (format , name = None , dtype = None , shape = None ):
504504 if dtype is None :
505505 dtype = config .floatX
506- type = SparseTensorType (format = format , dtype = dtype )
506+ type = SparseTensorType (format = format , dtype = dtype , shape = shape )
507507 return type (name )
508508
509509
510- def csc_matrix (name = None , dtype = None ):
511- return matrix ("csc" , name , dtype )
510+ def csc_matrix (name = None , dtype = None , shape = None ):
511+ return matrix ("csc" , name = name , dtype = dtype , shape = shape )
512512
513513
514- def csr_matrix (name = None , dtype = None ):
515- return matrix ("csr" , name , dtype )
514+ def csr_matrix (name = None , dtype = None , shape = None ):
515+ return matrix ("csr" , name = name , dtype = dtype , shape = shape )
516516
517517
518518def bsr_matrix (name = None , dtype = None ):
@@ -727,10 +727,22 @@ def make_node(self, data, indices, indptr, shape):
727727 if shape .type .ndim != 1 or shape .type .dtype not in discrete_dtypes :
728728 raise TypeError ("n_rows must be integer type" , shape , shape .type )
729729
730+ static_shape = (None , None )
731+ if (
732+ shape .owner is not None
733+ and isinstance (shape .owner .op , CSMProperties )
734+ and shape .owner .outputs [3 ] is shape
735+ ):
736+ static_shape = shape .owner .inputs [0 ].type .shape
737+
730738 return Apply (
731739 self ,
732740 [data , indices , indptr , shape ],
733- [SparseTensorType (dtype = data .type .dtype , format = self .format )()],
741+ [
742+ SparseTensorType (
743+ dtype = data .type .dtype , format = self .format , shape = static_shape
744+ )()
745+ ],
734746 )
735747
736748 def perform (self , node , inputs , outputs ):
0 commit comments