@@ -61,8 +61,8 @@ def test_debugprint_sitsot():
6161
6262 Scan{scan_fn, while_loop=False, inplace=none} [id C]
6363 ← Mul [id W] (inner_out_sit_sot-0)
64- ├─ *0-<TensorType (float64, (?,))> [id X] -> [id E] (inner_in_sit_sot-0)
65- └─ *1-<TensorType (float64, (?,))> [id Y] -> [id M] (inner_in_non_seqs-0)"""
64+ ├─ *0-<Vector (float64, shape= (?,))> [id X] -> [id E] (inner_in_sit_sot-0)
65+ └─ *1-<Vector (float64, shape= (?,))> [id Y] -> [id M] (inner_in_non_seqs-0)"""
6666
6767 for truth , out in zip (expected_output .split ("\n " ), lines ):
6868 assert truth .strip () == out .strip ()
@@ -119,8 +119,8 @@ def test_debugprint_sitsot_no_extra_info():
119119
120120 Scan{scan_fn, while_loop=False, inplace=none} [id C]
121121 ← Mul [id W]
122- ├─ *0-<TensorType (float64, (?,))> [id X] -> [id E]
123- └─ *1-<TensorType (float64, (?,))> [id Y] -> [id M]"""
122+ ├─ *0-<Vector (float64, shape= (?,))> [id X] -> [id E]
123+ └─ *1-<Vector (float64, shape= (?,))> [id Y] -> [id M]"""
124124
125125 for truth , out in zip (expected_output .split ("\n " ), lines ):
126126 assert truth .strip () == out .strip ()
@@ -185,10 +185,10 @@ def test_debugprint_nitsot():
185185
186186 Scan{scan_fn, while_loop=False, inplace=none} [id B]
187187 ← Mul [id X] (inner_out_nit_sot-0)
188- ├─ *0-<TensorType (float64, ())> [id Y] -> [id S] (inner_in_seqs-0)
188+ ├─ *0-<Scalar (float64, shape= ())> [id Y] -> [id S] (inner_in_seqs-0)
189189 └─ Pow [id Z]
190- ├─ *2-<TensorType (float64, ())> [id BA] -> [id W] (inner_in_non_seqs-0)
191- └─ *1-<TensorType (int64, ())> [id BB] -> [id U] (inner_in_seqs-1)"""
190+ ├─ *2-<Scalar (float64, shape= ())> [id BA] -> [id W] (inner_in_non_seqs-0)
191+ └─ *1-<Scalar (int64, shape= ())> [id BB] -> [id U] (inner_in_seqs-1)"""
192192
193193 for truth , out in zip (expected_output .split ("\n " ), lines ):
194194 assert truth .strip () == out .strip ()
@@ -265,22 +265,22 @@ def compute_A_k(A, k):
265265 Scan{scan_fn, while_loop=False, inplace=none} [id B]
266266 ← Mul [id Y] (inner_out_nit_sot-0)
267267 ├─ ExpandDims{axis=0} [id Z]
268- │ └─ *0-<TensorType (float64, ())> [id BA] -> [id S] (inner_in_seqs-0)
268+ │ └─ *0-<Scalar (float64, shape= ())> [id BA] -> [id S] (inner_in_seqs-0)
269269 └─ Pow [id BB]
270270 ├─ Subtensor{i} [id BC]
271271 │ ├─ Subtensor{start:} [id BD]
272272 │ │ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id BE] (outer_out_sit_sot-0)
273- │ │ │ ├─ *3-<TensorType (int32, ())> [id BF] -> [id X] (inner_in_non_seqs-1) (n_steps)
273+ │ │ │ ├─ *3-<Scalar (int32, shape= ())> [id BF] -> [id X] (inner_in_non_seqs-1) (n_steps)
274274 │ │ │ ├─ SetSubtensor{:stop} [id BG] (outer_in_sit_sot-0)
275275 │ │ │ │ ├─ AllocEmpty{dtype='float64'} [id BH]
276276 │ │ │ │ │ ├─ Add [id BI]
277- │ │ │ │ │ │ ├─ *3-<TensorType (int32, ())> [id BF] -> [id X] (inner_in_non_seqs-1)
277+ │ │ │ │ │ │ ├─ *3-<Scalar (int32, shape= ())> [id BF] -> [id X] (inner_in_non_seqs-1)
278278 │ │ │ │ │ │ └─ Subtensor{i} [id BJ]
279279 │ │ │ │ │ │ ├─ Shape [id BK]
280280 │ │ │ │ │ │ │ └─ Unbroadcast{0} [id BL]
281281 │ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id BM]
282282 │ │ │ │ │ │ │ └─ Second [id BN]
283- │ │ │ │ │ │ │ ├─ *2-<TensorType (float64, (?,))> [id BO] -> [id W] (inner_in_non_seqs-0)
283+ │ │ │ │ │ │ │ ├─ *2-<Vector (float64, shape= (?,))> [id BO] -> [id W] (inner_in_non_seqs-0)
284284 │ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id BP]
285285 │ │ │ │ │ │ │ └─ TensorConstant{1.0} [id BQ]
286286 │ │ │ │ │ │ └─ ScalarConstant{0} [id BR]
@@ -294,16 +294,16 @@ def compute_A_k(A, k):
294294 │ │ │ │ └─ ScalarFromTensor [id BV]
295295 │ │ │ │ └─ Subtensor{i} [id BJ]
296296 │ │ │ │ └─ ···
297- │ │ │ └─ *2-<TensorType (float64, (?,))> [id BO] -> [id W] (inner_in_non_seqs-0) (outer_in_non_seqs-0)
297+ │ │ │ └─ *2-<Vector (float64, shape= (?,))> [id BO] -> [id W] (inner_in_non_seqs-0) (outer_in_non_seqs-0)
298298 │ │ └─ ScalarConstant{1} [id BW]
299299 │ └─ ScalarConstant{-1} [id BX]
300300 └─ ExpandDims{axis=0} [id BY]
301- └─ *1-<TensorType (int64, ())> [id BZ] -> [id U] (inner_in_seqs-1)
301+ └─ *1-<Scalar (int64, shape= ())> [id BZ] -> [id U] (inner_in_seqs-1)
302302
303303 Scan{scan_fn, while_loop=False, inplace=none} [id BE]
304304 ← Mul [id CA] (inner_out_sit_sot-0)
305- ├─ *0-<TensorType (float64, (?,))> [id CB] -> [id BG] (inner_in_sit_sot-0)
306- └─ *1-<TensorType (float64, (?,))> [id CC] -> [id BO] (inner_in_non_seqs-0)"""
305+ ├─ *0-<Vector (float64, shape= (?,))> [id CB] -> [id BG] (inner_in_sit_sot-0)
306+ └─ *1-<Vector (float64, shape= (?,))> [id CC] -> [id BO] (inner_in_non_seqs-0)"""
307307
308308 for truth , out in zip (expected_output .split ("\n " ), lines ):
309309 assert truth .strip () == out .strip ()
@@ -356,28 +356,28 @@ def compute_A_k(A, k):
356356 Inner graphs:
357357
358358 Scan{scan_fn, while_loop=False, inplace=none} [id E]
359- → *0-<TensorType (float64, ())> [id Y] -> [id U] (inner_in_seqs-0)
360- → *1-<TensorType (int64, ())> [id Z] -> [id W] (inner_in_seqs-1)
361- → *2-<TensorType (float64, (?,))> [id BA] -> [id C] (inner_in_non_seqs-0)
362- → *3-<TensorType (int32, ())> [id BB] -> [id B] (inner_in_non_seqs-1)
359+ → *0-<Scalar (float64, shape= ())> [id Y] -> [id U] (inner_in_seqs-0)
360+ → *1-<Scalar (int64, shape= ())> [id Z] -> [id W] (inner_in_seqs-1)
361+ → *2-<Vector (float64, shape= (?,))> [id BA] -> [id C] (inner_in_non_seqs-0)
362+ → *3-<Scalar (int32, shape= ())> [id BB] -> [id B] (inner_in_non_seqs-1)
363363 ← Mul [id BC] (inner_out_nit_sot-0)
364364 ├─ ExpandDims{axis=0} [id BD]
365- │ └─ *0-<TensorType (float64, ())> [id Y] (inner_in_seqs-0)
365+ │ └─ *0-<Scalar (float64, shape= ())> [id Y] (inner_in_seqs-0)
366366 └─ Pow [id BE]
367367 ├─ Subtensor{i} [id BF]
368368 │ ├─ Subtensor{start:} [id BG]
369369 │ │ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id BH] (outer_out_sit_sot-0)
370- │ │ │ ├─ *3-<TensorType (int32, ())> [id BB] (inner_in_non_seqs-1) (n_steps)
370+ │ │ │ ├─ *3-<Scalar (int32, shape= ())> [id BB] (inner_in_non_seqs-1) (n_steps)
371371 │ │ │ ├─ SetSubtensor{:stop} [id BI] (outer_in_sit_sot-0)
372372 │ │ │ │ ├─ AllocEmpty{dtype='float64'} [id BJ]
373373 │ │ │ │ │ ├─ Add [id BK]
374- │ │ │ │ │ │ ├─ *3-<TensorType (int32, ())> [id BB] (inner_in_non_seqs-1)
374+ │ │ │ │ │ │ ├─ *3-<Scalar (int32, shape= ())> [id BB] (inner_in_non_seqs-1)
375375 │ │ │ │ │ │ └─ Subtensor{i} [id BL]
376376 │ │ │ │ │ │ ├─ Shape [id BM]
377377 │ │ │ │ │ │ │ └─ Unbroadcast{0} [id BN]
378378 │ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id BO]
379379 │ │ │ │ │ │ │ └─ Second [id BP]
380- │ │ │ │ │ │ │ ├─ *2-<TensorType (float64, (?,))> [id BA] (inner_in_non_seqs-0)
380+ │ │ │ │ │ │ │ ├─ *2-<Vector (float64, shape= (?,))> [id BA] (inner_in_non_seqs-0)
381381 │ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id BQ]
382382 │ │ │ │ │ │ │ └─ TensorConstant{1.0} [id BR]
383383 │ │ │ │ │ │ └─ ScalarConstant{0} [id BS]
@@ -391,18 +391,18 @@ def compute_A_k(A, k):
391391 │ │ │ │ └─ ScalarFromTensor [id BW]
392392 │ │ │ │ └─ Subtensor{i} [id BL]
393393 │ │ │ │ └─ ···
394- │ │ │ └─ *2-<TensorType (float64, (?,))> [id BA] (inner_in_non_seqs-0) (outer_in_non_seqs-0)
394+ │ │ │ └─ *2-<Vector (float64, shape= (?,))> [id BA] (inner_in_non_seqs-0) (outer_in_non_seqs-0)
395395 │ │ └─ ScalarConstant{1} [id BX]
396396 │ └─ ScalarConstant{-1} [id BY]
397397 └─ ExpandDims{axis=0} [id BZ]
398- └─ *1-<TensorType (int64, ())> [id Z] (inner_in_seqs-1)
398+ └─ *1-<Scalar (int64, shape= ())> [id Z] (inner_in_seqs-1)
399399
400400 Scan{scan_fn, while_loop=False, inplace=none} [id BH]
401- → *0-<TensorType (float64, (?,))> [id CA] -> [id BI] (inner_in_sit_sot-0)
402- → *1-<TensorType (float64, (?,))> [id CB] -> [id BA] (inner_in_non_seqs-0)
401+ → *0-<Vector (float64, shape= (?,))> [id CA] -> [id BI] (inner_in_sit_sot-0)
402+ → *1-<Vector (float64, shape= (?,))> [id CB] -> [id BA] (inner_in_non_seqs-0)
403403 ← Mul [id CC] (inner_out_sit_sot-0)
404- ├─ *0-<TensorType (float64, (?,))> [id CA] (inner_in_sit_sot-0)
405- └─ *1-<TensorType (float64, (?,))> [id CB] (inner_in_non_seqs-0)"""
404+ ├─ *0-<Vector (float64, shape= (?,))> [id CA] (inner_in_sit_sot-0)
405+ └─ *1-<Vector (float64, shape= (?,))> [id CB] (inner_in_non_seqs-0)"""
406406
407407 for truth , out in zip (expected_output .split ("\n " ), lines ):
408408 assert truth .strip () == out .strip ()
@@ -440,7 +440,7 @@ def fn(a_m2, a_m1, b_m2, b_m1):
440440 │ │ │ │ └─ Subtensor{i} [id H]
441441 │ │ │ │ ├─ Shape [id I]
442442 │ │ │ │ │ └─ Subtensor{:stop} [id J]
443- │ │ │ │ │ ├─ <TensorType (int64, (?,))> [id K]
443+ │ │ │ │ │ ├─ <Vector (int64, shape= (?,))> [id K]
444444 │ │ │ │ │ └─ ScalarConstant{2} [id L]
445445 │ │ │ │ └─ ScalarConstant{0} [id M]
446446 │ │ │ ├─ Subtensor{:stop} [id J]
@@ -455,7 +455,7 @@ def fn(a_m2, a_m1, b_m2, b_m1):
455455 │ │ │ └─ Subtensor{i} [id R]
456456 │ │ │ ├─ Shape [id S]
457457 │ │ │ │ └─ Subtensor{:stop} [id T]
458- │ │ │ │ ├─ <TensorType (int64, (?,))> [id U]
458+ │ │ │ │ ├─ <Vector (int64, shape= (?,))> [id U]
459459 │ │ │ │ └─ ScalarConstant{2} [id V]
460460 │ │ │ └─ ScalarConstant{0} [id W]
461461 │ │ ├─ Subtensor{:stop} [id T]
@@ -473,11 +473,11 @@ def fn(a_m2, a_m1, b_m2, b_m1):
473473
474474 Scan{scan_fn, while_loop=False, inplace=none} [id C]
475475 ← Add [id BB] (inner_out_mit_sot-0)
476- ├─ *1-<TensorType (int64, ())> [id BC] -> [id E] (inner_in_mit_sot-0-1)
477- └─ *0-<TensorType (int64, ())> [id BD] -> [id E] (inner_in_mit_sot-0-0)
476+ ├─ *1-<Scalar (int64, shape= ())> [id BC] -> [id E] (inner_in_mit_sot-0-1)
477+ └─ *0-<Scalar (int64, shape= ())> [id BD] -> [id E] (inner_in_mit_sot-0-0)
478478 ← Add [id BE] (inner_out_mit_sot-1)
479- ├─ *3-<TensorType (int64, ())> [id BF] -> [id O] (inner_in_mit_sot-1-1)
480- └─ *2-<TensorType (int64, ())> [id BG] -> [id O] (inner_in_mit_sot-1-0)"""
479+ ├─ *3-<Scalar (int64, shape= ())> [id BF] -> [id O] (inner_in_mit_sot-1-1)
480+ └─ *2-<Scalar (int64, shape= ())> [id BG] -> [id O] (inner_in_mit_sot-1-0)"""
481481
482482 for truth , out in zip (expected_output .split ("\n " ), lines ):
483483 assert truth .strip () == out .strip ()
@@ -601,19 +601,19 @@ def test_debugprint_mitmot():
601601 Scan{grad_of_scan_fn, while_loop=False, inplace=none} [id B]
602602 ← Add [id CM] (inner_out_mit_mot-0-0)
603603 ├─ Mul [id CN]
604- │ ├─ *2-<TensorType (float64, (?,))> [id CO] -> [id BL] (inner_in_mit_mot-0-0)
605- │ └─ *5-<TensorType (float64, (?,))> [id CP] -> [id P] (inner_in_non_seqs-0)
606- └─ *3-<TensorType (float64, (?,))> [id CQ] -> [id BL] (inner_in_mit_mot-0-1)
604+ │ ├─ *2-<Vector (float64, shape= (?,))> [id CO] -> [id BL] (inner_in_mit_mot-0-0)
605+ │ └─ *5-<Vector (float64, shape= (?,))> [id CP] -> [id P] (inner_in_non_seqs-0)
606+ └─ *3-<Vector (float64, shape= (?,))> [id CQ] -> [id BL] (inner_in_mit_mot-0-1)
607607 ← Add [id CR] (inner_out_sit_sot-0)
608608 ├─ Mul [id CS]
609- │ ├─ *2-<TensorType (float64, (?,))> [id CO] -> [id BL] (inner_in_mit_mot-0-0)
610- │ └─ *0-<TensorType (float64, (?,))> [id CT] -> [id Z] (inner_in_seqs-0)
611- └─ *4-<TensorType (float64, (?,))> [id CU] -> [id CE] (inner_in_sit_sot-0)
609+ │ ├─ *2-<Vector (float64, shape= (?,))> [id CO] -> [id BL] (inner_in_mit_mot-0-0)
610+ │ └─ *0-<Vector (float64, shape= (?,))> [id CT] -> [id Z] (inner_in_seqs-0)
611+ └─ *4-<Vector (float64, shape= (?,))> [id CU] -> [id CE] (inner_in_sit_sot-0)
612612
613613 Scan{scan_fn, while_loop=False, inplace=none} [id F]
614614 ← Mul [id CV] (inner_out_sit_sot-0)
615- ├─ *0-<TensorType (float64, (?,))> [id CT] -> [id H] (inner_in_sit_sot-0)
616- └─ *1-<TensorType (float64, (?,))> [id CW] -> [id P] (inner_in_non_seqs-0)"""
615+ ├─ *0-<Vector (float64, shape= (?,))> [id CT] -> [id H] (inner_in_sit_sot-0)
616+ └─ *1-<Vector (float64, shape= (?,))> [id CW] -> [id P] (inner_in_non_seqs-0)"""
617617
618618 for truth , out in zip (expected_output .split ("\n " ), lines ):
619619 assert truth .strip () == out .strip ()
@@ -643,25 +643,25 @@ def no_shared_fn(n, x_tm1, M):
643643
644644 expected_output = """Scan{scan_fn, while_loop=False, inplace=all} [id A] 2 (outer_out_sit_sot-0)
645645 ├─ TensorConstant{20000} [id B] (n_steps)
646- ├─ TensorConstant{[ 0 .. 998 19999]} [id C] (outer_in_seqs-0)
646+ ├─ TensorConstant{[ 0 ... 998 19999]} [id C] (outer_in_seqs-0)
647647 ├─ SetSubtensor{:stop} [id D] 1 (outer_in_sit_sot-0)
648648 │ ├─ AllocEmpty{dtype='int64'} [id E] 0
649649 │ │ └─ TensorConstant{20000} [id B]
650650 │ ├─ TensorConstant{(1,) of 0} [id F]
651651 │ └─ ScalarConstant{1} [id G]
652- └─ <TensorType (float64, (20000, 2, 2))> [id H] (outer_in_non_seqs-0)
652+ └─ <Tensor3 (float64, shape= (20000, 2, 2))> [id H] (outer_in_non_seqs-0)
653653
654654 Inner graphs:
655655
656656 Scan{scan_fn, while_loop=False, inplace=all} [id A]
657657 ← Composite{switch(lt(i0, i1), i2, i0)} [id I] (inner_out_sit_sot-0)
658658 ├─ TensorConstant{0} [id J]
659659 ├─ Subtensor{i, j, k} [id K]
660- │ ├─ *2-<TensorType (float64, (20000, 2, 2))> [id L] -> [id H] (inner_in_non_seqs-0)
660+ │ ├─ *2-<Tensor3 (float64, shape= (20000, 2, 2))> [id L] -> [id H] (inner_in_non_seqs-0)
661661 │ ├─ ScalarFromTensor [id M]
662- │ │ └─ *0-<TensorType (int64, ())> [id N] -> [id C] (inner_in_seqs-0)
662+ │ │ └─ *0-<Scalar (int64, shape= ())> [id N] -> [id C] (inner_in_seqs-0)
663663 │ ├─ ScalarFromTensor [id O]
664- │ │ └─ *1-<TensorType (int64, ())> [id P] -> [id D] (inner_in_sit_sot-0)
664+ │ │ └─ *1-<Scalar (int64, shape= ())> [id P] -> [id D] (inner_in_sit_sot-0)
665665 │ └─ ScalarConstant{0} [id Q]
666666 └─ TensorConstant{1} [id R]
667667
0 commit comments