@@ -1112,23 +1112,24 @@ def tril(m, k=0):
11121112
11131113 Examples
11141114 --------
1115- >>> at.tril(np.arange(1,13).reshape(4,3), -1).eval()
1115+ >>> import pytensor.tensor as pt
1116+ >>> pt.tril(pt.arange(1,13).reshape((4,3)), -1).eval()
11161117 array([[ 0, 0, 0],
11171118 [ 4, 0, 0],
11181119 [ 7, 8, 0],
11191120 [10, 11, 12]])
11201121
1121- >>> at .tril(np .arange(3*4*5).reshape(3, 4, 5)).eval()
1122+ >>> pt .tril(pt .arange(3*4*5).reshape(( 3, 4, 5) )).eval()
11221123 array([[[ 0, 0, 0, 0, 0],
11231124 [ 5, 6, 0, 0, 0],
11241125 [10, 11, 12, 0, 0],
11251126 [15, 16, 17, 18, 0]],
1126-
1127+ <BLANKLINE>
11271128 [[20, 0, 0, 0, 0],
11281129 [25, 26, 0, 0, 0],
11291130 [30, 31, 32, 0, 0],
11301131 [35, 36, 37, 38, 0]],
1131-
1132+ <BLANKLINE>
11321133 [[40, 0, 0, 0, 0],
11331134 [45, 46, 0, 0, 0],
11341135 [50, 51, 52, 0, 0],
@@ -1154,23 +1155,24 @@ def triu(m, k=0):
11541155
11551156 Examples
11561157 --------
1157- >>> at.triu(np.arange(1,13).reshape(4,3), -1).eval()
1158+ >>> import pytensor.tensor as pt
1159+ >>> pt.triu(pt.arange(1, 13).reshape((4, 3)), -1).eval()
11581160 array([[ 1, 2, 3],
11591161 [ 4, 5, 6],
11601162 [ 0, 8, 9],
11611163 [ 0, 0, 12]])
11621164
1163- >>> at .triu(np.arange(3*4*5).reshape(3, 4, 5)).eval()
1165+ >>> pt .triu(np.arange(3*4*5).reshape(( 3, 4, 5) )).eval()
11641166 array([[[ 0, 1, 2, 3, 4],
11651167 [ 0, 6, 7, 8, 9],
11661168 [ 0, 0, 12, 13, 14],
11671169 [ 0, 0, 0, 18, 19]],
1168-
1170+ <BLANKLINE>
11691171 [[20, 21, 22, 23, 24],
11701172 [ 0, 26, 27, 28, 29],
11711173 [ 0, 0, 32, 33, 34],
11721174 [ 0, 0, 0, 38, 39]],
1173-
1175+ <BLANKLINE>
11741176 [[40, 41, 42, 43, 44],
11751177 [ 0, 46, 47, 48, 49],
11761178 [ 0, 0, 52, 53, 54],
@@ -2024,28 +2026,14 @@ def matrix_transpose(x: "TensorLike") -> TensorVariable:
20242026
20252027 Examples
20262028 --------
2027- >>> import pytensor as pt
2028- >>> import numpy as np
2029- >>> x = np.arange(24).reshape((2, 3, 4))
2030- [[[ 0 1 2 3]
2031- [ 4 5 6 7]
2032- [ 8 9 10 11]]
2033-
2034- [[12 13 14 15]
2035- [16 17 18 19]
2036- [20 21 22 23]]]
2029+ >>> import pytensor.tensor as pt
2030+ >>> x = pt.arange(24).reshape((2, 3, 4))
2031+ >>> x.type.shape
2032+ (2, 3, 4)
20372033
2034+ >>> pt.matrix_transpose(x).type.shape
2035+ (2, 4, 3)
20382036
2039- >>> pt.matrix_transpose(x).eval()
2040- [[[ 0 4 8]
2041- [ 1 5 9]
2042- [ 2 6 10]
2043- [ 3 7 11]]
2044-
2045- [[12 16 20]
2046- [13 17 21]
2047- [14 18 22]
2048- [15 19 23]]]
20492037
20502038
20512039 Notes
@@ -2072,15 +2060,21 @@ class Split(COp):
20722060
20732061 Examples
20742062 --------
2075- >>> x = vector()
2076- >>> splits = lvector()
2063+ >>> from pytensor import function
2064+ >>> import pytensor.tensor as pt
2065+ >>> x = pt.vector(dtype="int")
2066+ >>> splits = pt.vector(dtype="int")
2067+
20772068 You have to declare right away how many split_points there will be.
2078- >>> ra, rb, rc = split(x, splits, n_splits = 3, axis = 0)
2069+ >>> ra, rb, rc = pt. split(x, splits, n_splits = 3, axis = 0)
20792070 >>> f = function([x, splits], [ra, rb, rc])
20802071 >>> a, b, c = f([0,1,2,3,4,5], [3, 2, 1])
2081- a == [0,1,2]
2082- b == [3, 4]
2083- c == [5]
2072+ >>> a
2073+ array([0, 1, 2])
2074+ >>> b
2075+ array([3, 4])
2076+ >>> c
2077+ array([5])
20842078
20852079 TODO: Don't make a copy in C impl
20862080 """
@@ -2329,13 +2323,22 @@ class Join(COp):
23292323
23302324 Examples
23312325 --------
2332- >>> x, y, z = tensor.matrix(), tensor.matrix(), tensor.matrix()
2333- >>> u = tensor.vector()
2326+ >>> import pytensor.tensor as pt
2327+ >>> x, y, z = pt.matrix(), pt.matrix(), pt.matrix()
2328+ >>> u = pt.vector()
2329+
2330+ >>> r = pt.join(0, x, y, z)
2331+ >>> c = pt.join(1, x, y, z)
2332+
2333+ The axis has to be an index into the shape
2334+ >>> pt.join(2, x, y, z)
2335+ Traceback (most recent call last):
2336+ ValueError: Axis value 2 is out of range for the given input dimensions
23342337
2335- >>> r = join(0, x, y, z)
2336- >>> c = join(1 , x, y, z )
2337- >>> join(2, x, y, z) # WRONG: the axis has to be an index into the shape
2338- >>> join(0, x, u) # WRONG: joined tensors must have the same rank
2338+ Joined tensors must have the same rank
2339+ >>> pt. join(0 , x, u )
2340+ Traceback (most recent call last):
2341+ TypeError: Only tensors with the same number of dimensions can be joined. Input ndims were: [2, 1].
23392342
23402343 """
23412344
@@ -3232,28 +3235,29 @@ class _nd_grid:
32323235
32333236 Examples
32343237 --------
3235- >>> a = at.mgrid[0:5, 0:3]
3238+ >>> import pytensor.tensor as pt
3239+ >>> a = pt.mgrid[0:5, 0:3]
32363240 >>> a[0].eval()
32373241 array([[0, 0, 0],
32383242 [1, 1, 1],
32393243 [2, 2, 2],
32403244 [3, 3, 3],
3241- [4, 4, 4]], dtype=int8 )
3245+ [4, 4, 4]])
32423246 >>> a[1].eval()
32433247 array([[0, 1, 2],
32443248 [0, 1, 2],
32453249 [0, 1, 2],
32463250 [0, 1, 2],
3247- [0, 1, 2]], dtype=int8 )
3248- >>> b = at .ogrid[0:5, 0:3]
3251+ [0, 1, 2]])
3252+ >>> b = pt .ogrid[0:5, 0:3]
32493253 >>> b[0].eval()
32503254 array([[0],
32513255 [1],
32523256 [2],
32533257 [3],
3254- [4]], dtype=int8 )
3258+ [4]])
32553259 >>> b[1].eval()
3256- array([[0, 1, 2, 3]], dtype=int8 )
3260+ array([[0, 1, 2]] )
32573261
32583262 """
32593263
@@ -3915,8 +3919,8 @@ def stacklists(arg):
39153919 >>> X = stacklists([[a, b], [c, d]])
39163920 >>> f = function([a, b, c, d], X)
39173921 >>> f(1, 2, 3, 4)
3918- array([[ 1., 2.],
3919- [ 3., 4.]], dtype=float32 )
3922+ array([[1., 2.],
3923+ [3., 4.]])
39203924
39213925 We can also stack arbitrarily shaped tensors. Here we stack matrices into
39223926 a 2 by 2 grid:
0 commit comments