@@ -28,7 +28,7 @@ def test_debugprint_sitsot():
2828
2929 expected_output = """Subtensor{i} [id A]
3030 ├─ Subtensor{start:} [id B]
31- │ ├─ for{cpu, scan_fn} [id C] (outer_out_sit_sot-0)
31+ │ ├─ Scan{ scan_fn, while_loop=False, inplace=none } [id C] (outer_out_sit_sot-0)
3232 │ │ ├─ k [id D] (n_steps)
3333 │ │ ├─ SetSubtensor{:stop} [id E] (outer_in_sit_sot-0)
3434 │ │ │ ├─ AllocEmpty{dtype='float64'} [id F]
@@ -59,7 +59,7 @@ def test_debugprint_sitsot():
5959
6060 Inner graphs:
6161
62- for{cpu, scan_fn} [id C]
62+ Scan{ scan_fn, while_loop=False, inplace=none } [id C]
6363 ← Mul [id W] (inner_out_sit_sot-0)
6464 ├─ *0-<TensorType(float64, (?,))> [id X] -> [id E] (inner_in_sit_sot-0)
6565 └─ *1-<TensorType(float64, (?,))> [id Y] -> [id M] (inner_in_non_seqs-0)"""
@@ -86,7 +86,7 @@ def test_debugprint_sitsot_no_extra_info():
8686
8787 expected_output = """Subtensor{i} [id A]
8888 ├─ Subtensor{start:} [id B]
89- │ ├─ for{cpu, scan_fn} [id C]
89+ │ ├─ Scan{ scan_fn, while_loop=False, inplace=none } [id C]
9090 │ │ ├─ k [id D]
9191 │ │ ├─ SetSubtensor{:stop} [id E]
9292 │ │ │ ├─ AllocEmpty{dtype='float64'} [id F]
@@ -117,7 +117,7 @@ def test_debugprint_sitsot_no_extra_info():
117117
118118 Inner graphs:
119119
120- for{cpu, scan_fn} [id C]
120+ Scan{ scan_fn, while_loop=False, inplace=none } [id C]
121121 ← Mul [id W]
122122 ├─ *0-<TensorType(float64, (?,))> [id X] -> [id E]
123123 └─ *1-<TensorType(float64, (?,))> [id Y] -> [id M]"""
@@ -148,7 +148,7 @@ def test_debugprint_nitsot():
148148 lines = output_str .split ("\n " )
149149
150150 expected_output = """Sum{axes=None} [id A]
151- └─ for{cpu, scan_fn} [id B] (outer_out_nit_sot-0)
151+ └─ Scan{ scan_fn, while_loop=False, inplace=none } [id B] (outer_out_nit_sot-0)
152152 ├─ Minimum [id C] (outer_in_nit_sot-0)
153153 │ ├─ Subtensor{i} [id D]
154154 │ │ ├─ Shape [id E]
@@ -183,7 +183,7 @@ def test_debugprint_nitsot():
183183
184184 Inner graphs:
185185
186- for{cpu, scan_fn} [id B]
186+ Scan{ scan_fn, while_loop=False, inplace=none } [id B]
187187 ← Mul [id X] (inner_out_nit_sot-0)
188188 ├─ *0-<TensorType(float64, ())> [id Y] -> [id S] (inner_in_seqs-0)
189189 └─ Pow [id Z]
@@ -226,7 +226,7 @@ def compute_A_k(A, k):
226226 lines = output_str .split ("\n " )
227227
228228 expected_output = """Sum{axes=None} [id A]
229- └─ for{cpu, scan_fn} [id B] (outer_out_nit_sot-0)
229+ └─ Scan{ scan_fn, while_loop=False, inplace=none } [id B] (outer_out_nit_sot-0)
230230 ├─ Minimum [id C] (outer_in_nit_sot-0)
231231 │ ├─ Subtensor{i} [id D]
232232 │ │ ├─ Shape [id E]
@@ -262,14 +262,14 @@ def compute_A_k(A, k):
262262
263263 Inner graphs:
264264
265- for{cpu, scan_fn} [id B]
265+ Scan{ scan_fn, while_loop=False, inplace=none } [id B]
266266 ← Mul [id Y] (inner_out_nit_sot-0)
267267 ├─ ExpandDims{axis=0} [id Z]
268268 │ └─ *0-<TensorType(float64, ())> [id BA] -> [id S] (inner_in_seqs-0)
269269 └─ Pow [id BB]
270270 ├─ Subtensor{i} [id BC]
271271 │ ├─ Subtensor{start:} [id BD]
272- │ │ ├─ for{cpu, scan_fn} [id BE] (outer_out_sit_sot-0)
272+ │ │ ├─ Scan{ scan_fn, while_loop=False, inplace=none } [id BE] (outer_out_sit_sot-0)
273273 │ │ │ ├─ *3-<TensorType(int32, ())> [id BF] -> [id X] (inner_in_non_seqs-1) (n_steps)
274274 │ │ │ ├─ SetSubtensor{:stop} [id BG] (outer_in_sit_sot-0)
275275 │ │ │ │ ├─ AllocEmpty{dtype='float64'} [id BH]
@@ -300,7 +300,7 @@ def compute_A_k(A, k):
300300 └─ ExpandDims{axis=0} [id BY]
301301 └─ *1-<TensorType(int64, ())> [id BZ] -> [id U] (inner_in_seqs-1)
302302
303- for{cpu, scan_fn} [id BE]
303+ Scan{ scan_fn, while_loop=False, inplace=none } [id BE]
304304 ← Mul [id CA] (inner_out_sit_sot-0)
305305 ├─ *0-<TensorType(float64, (?,))> [id CB] -> [id BG] (inner_in_sit_sot-0)
306306 └─ *1-<TensorType(float64, (?,))> [id CC] -> [id BO] (inner_in_non_seqs-0)"""
@@ -319,7 +319,7 @@ def compute_A_k(A, k):
319319 → k [id B]
320320 → A [id C]
321321 Sum{axes=None} [id D] 13
322- └─ for{cpu, scan_fn} [id E] 12 (outer_out_nit_sot-0)
322+ └─ Scan{ scan_fn, while_loop=False, inplace=none } [id E] 12 (outer_out_nit_sot-0)
323323 ├─ Minimum [id F] 7 (outer_in_nit_sot-0)
324324 │ ├─ Subtensor{i} [id G] 6
325325 │ │ ├─ Shape [id H] 5
@@ -355,7 +355,7 @@ def compute_A_k(A, k):
355355
356356 Inner graphs:
357357
358- for{cpu, scan_fn} [id E]
358+ Scan{ scan_fn, while_loop=False, inplace=none } [id E]
359359 → *0-<TensorType(float64, ())> [id Y] -> [id U] (inner_in_seqs-0)
360360 → *1-<TensorType(int64, ())> [id Z] -> [id W] (inner_in_seqs-1)
361361 → *2-<TensorType(float64, (?,))> [id BA] -> [id C] (inner_in_non_seqs-0)
@@ -366,7 +366,7 @@ def compute_A_k(A, k):
366366 └─ Pow [id BE]
367367 ├─ Subtensor{i} [id BF]
368368 │ ├─ Subtensor{start:} [id BG]
369- │ │ ├─ for{cpu, scan_fn} [id BH] (outer_out_sit_sot-0)
369+ │ │ ├─ Scan{ scan_fn, while_loop=False, inplace=none } [id BH] (outer_out_sit_sot-0)
370370 │ │ │ ├─ *3-<TensorType(int32, ())> [id BB] (inner_in_non_seqs-1) (n_steps)
371371 │ │ │ ├─ SetSubtensor{:stop} [id BI] (outer_in_sit_sot-0)
372372 │ │ │ │ ├─ AllocEmpty{dtype='float64'} [id BJ]
@@ -397,7 +397,7 @@ def compute_A_k(A, k):
397397 └─ ExpandDims{axis=0} [id BZ]
398398 └─ *1-<TensorType(int64, ())> [id Z] (inner_in_seqs-1)
399399
400- for{cpu, scan_fn} [id BH]
400+ Scan{ scan_fn, while_loop=False, inplace=none } [id BH]
401401 → *0-<TensorType(float64, (?,))> [id CA] -> [id BI] (inner_in_sit_sot-0)
402402 → *1-<TensorType(float64, (?,))> [id CB] -> [id BA] (inner_in_non_seqs-0)
403403 ← Mul [id CC] (inner_out_sit_sot-0)
@@ -431,7 +431,7 @@ def fn(a_m2, a_m1, b_m2, b_m1):
431431
432432 expected_output = """Add [id A]
433433 ├─ Subtensor{start:} [id B]
434- │ ├─ for{cpu, scan_fn}.0 [id C] (outer_out_mit_sot-0)
434+ │ ├─ Scan{ scan_fn, while_loop=False, inplace=none }.0 [id C] (outer_out_mit_sot-0)
435435 │ │ ├─ TensorConstant{5} [id D] (n_steps)
436436 │ │ ├─ SetSubtensor{:stop} [id E] (outer_in_mit_sot-0)
437437 │ │ │ ├─ AllocEmpty{dtype='int64'} [id F]
@@ -465,13 +465,13 @@ def fn(a_m2, a_m1, b_m2, b_m1):
465465 │ │ └─ ···
466466 │ └─ ScalarConstant{2} [id Y]
467467 └─ Subtensor{start:} [id Z]
468- ├─ for{cpu, scan_fn}.1 [id C] (outer_out_mit_sot-1)
468+ ├─ Scan{ scan_fn, while_loop=False, inplace=none }.1 [id C] (outer_out_mit_sot-1)
469469 │ └─ ···
470470 └─ ScalarConstant{2} [id BA]
471471
472472 Inner graphs:
473473
474- for{cpu, scan_fn} [id C]
474+ Scan{ scan_fn, while_loop=False, inplace=none } [id C]
475475 ← Add [id BB] (inner_out_mit_sot-0)
476476 ├─ *1-<TensorType(int64, ())> [id BC] -> [id E] (inner_in_mit_sot-0-1)
477477 └─ *0-<TensorType(int64, ())> [id BD] -> [id E] (inner_in_mit_sot-0-0)
@@ -502,11 +502,11 @@ def test_debugprint_mitmot():
502502 lines = output_str .split ("\n " )
503503
504504 expected_output = """Subtensor{i} [id A]
505- ├─ for{cpu, grad_of_scan_fn}.1 [id B] (outer_out_sit_sot-0)
505+ ├─ Scan{ grad_of_scan_fn, while_loop=False, inplace=none }.1 [id B] (outer_out_sit_sot-0)
506506 │ ├─ Sub [id C] (n_steps)
507507 │ │ ├─ Subtensor{i} [id D]
508508 │ │ │ ├─ Shape [id E]
509- │ │ │ │ └─ for{cpu, scan_fn} [id F] (outer_out_sit_sot-0)
509+ │ │ │ │ └─ Scan{ scan_fn, while_loop=False, inplace=none } [id F] (outer_out_sit_sot-0)
510510 │ │ │ │ ├─ k [id G] (n_steps)
511511 │ │ │ │ ├─ SetSubtensor{:stop} [id H] (outer_in_sit_sot-0)
512512 │ │ │ │ │ ├─ AllocEmpty{dtype='float64'} [id I]
@@ -537,7 +537,7 @@ def test_debugprint_mitmot():
537537 │ ├─ Subtensor{:stop} [id Z] (outer_in_seqs-0)
538538 │ │ ├─ Subtensor{::step} [id BA]
539539 │ │ │ ├─ Subtensor{:stop} [id BB]
540- │ │ │ │ ├─ for{cpu, scan_fn} [id F] (outer_out_sit_sot-0)
540+ │ │ │ │ ├─ Scan{ scan_fn, while_loop=False, inplace=none } [id F] (outer_out_sit_sot-0)
541541 │ │ │ │ │ └─ ···
542542 │ │ │ │ └─ ScalarConstant{-1} [id BC]
543543 │ │ │ └─ ScalarConstant{-1} [id BD]
@@ -547,7 +547,7 @@ def test_debugprint_mitmot():
547547 │ ├─ Subtensor{:stop} [id BF] (outer_in_seqs-1)
548548 │ │ ├─ Subtensor{:stop} [id BG]
549549 │ │ │ ├─ Subtensor{::step} [id BH]
550- │ │ │ │ ├─ for{cpu, scan_fn} [id F] (outer_out_sit_sot-0)
550+ │ │ │ │ ├─ Scan{ scan_fn, while_loop=False, inplace=none } [id F] (outer_out_sit_sot-0)
551551 │ │ │ │ │ └─ ···
552552 │ │ │ │ └─ ScalarConstant{-1} [id BI]
553553 │ │ │ └─ ScalarConstant{-1} [id BJ]
@@ -557,14 +557,14 @@ def test_debugprint_mitmot():
557557 │ ├─ Subtensor{::step} [id BL] (outer_in_mit_mot-0)
558558 │ │ ├─ IncSubtensor{start:} [id BM]
559559 │ │ │ ├─ Second [id BN]
560- │ │ │ │ ├─ for{cpu, scan_fn} [id F] (outer_out_sit_sot-0)
560+ │ │ │ │ ├─ Scan{ scan_fn, while_loop=False, inplace=none } [id F] (outer_out_sit_sot-0)
561561 │ │ │ │ │ └─ ···
562562 │ │ │ │ └─ ExpandDims{axes=[0, 1]} [id BO]
563563 │ │ │ │ └─ TensorConstant{0.0} [id BP]
564564 │ │ │ ├─ IncSubtensor{i} [id BQ]
565565 │ │ │ │ ├─ Second [id BR]
566566 │ │ │ │ │ ├─ Subtensor{start:} [id BS]
567- │ │ │ │ │ │ ├─ for{cpu, scan_fn} [id F] (outer_out_sit_sot-0)
567+ │ │ │ │ │ │ ├─ Scan{ scan_fn, while_loop=False, inplace=none } [id F] (outer_out_sit_sot-0)
568568 │ │ │ │ │ │ │ └─ ···
569569 │ │ │ │ │ │ └─ ScalarConstant{1} [id BT]
570570 │ │ │ │ │ └─ ExpandDims{axes=[0, 1]} [id BU]
@@ -598,7 +598,7 @@ def test_debugprint_mitmot():
598598
599599 Inner graphs:
600600
601- for{cpu, grad_of_scan_fn} [id B]
601+ Scan{ grad_of_scan_fn, while_loop=False, inplace=none } [id B]
602602 ← Add [id CM] (inner_out_mit_mot-0-0)
603603 ├─ Mul [id CN]
604604 │ ├─ *2-<TensorType(float64, (?,))> [id CO] -> [id BL] (inner_in_mit_mot-0-0)
@@ -610,7 +610,7 @@ def test_debugprint_mitmot():
610610 │ └─ *0-<TensorType(float64, (?,))> [id CT] -> [id Z] (inner_in_seqs-0)
611611 └─ *4-<TensorType(float64, (?,))> [id CU] -> [id CE] (inner_in_sit_sot-0)
612612
613- for{cpu, scan_fn} [id F]
613+ Scan{ scan_fn, while_loop=False, inplace=none } [id F]
614614 ← Mul [id CV] (inner_out_sit_sot-0)
615615 ├─ *0-<TensorType(float64, (?,))> [id CT] -> [id H] (inner_in_sit_sot-0)
616616 └─ *1-<TensorType(float64, (?,))> [id CW] -> [id P] (inner_in_non_seqs-0)"""
@@ -641,7 +641,7 @@ def no_shared_fn(n, x_tm1, M):
641641 # (i.e. from `Scan._fn`)
642642 out = pytensor .function ([M ], out , updates = updates , mode = "FAST_RUN" )
643643
644- expected_output = """forall_inplace,cpu,scan_fn } [id A] 2 (outer_out_sit_sot-0)
644+ expected_output = """Scan{scan_fn, while_loop=False, inplace=all } [id A] 2 (outer_out_sit_sot-0)
645645 ├─ TensorConstant{20000} [id B] (n_steps)
646646 ├─ TensorConstant{[ 0 ..998 19999]} [id C] (outer_in_seqs-0)
647647 ├─ SetSubtensor{:stop} [id D] 1 (outer_in_sit_sot-0)
@@ -653,7 +653,7 @@ def no_shared_fn(n, x_tm1, M):
653653
654654 Inner graphs:
655655
656- forall_inplace,cpu,scan_fn } [id A]
656+ Scan{scan_fn, while_loop=False, inplace=all } [id A]
657657 ← Composite{switch(lt(i0, i1), i2, i0)} [id I] (inner_out_sit_sot-0)
658658 ├─ TensorConstant{0} [id J]
659659 ├─ Subtensor{i, j, k} [id K]
0 commit comments