@@ -1771,6 +1771,7 @@ def _import_hop_while_loop(
17711771 self ._multi_result_nodes .add (node )
17721772 else :
17731773 result_types = [self ._cc .node_val_to_type (node )]
1774+
17741775 # Call the condition function with initial carries to get initial condition
17751776 cond_result_type = self ._cc .get_vtensor_type (torch .Size ([]), torch .bool )
17761777
@@ -1781,13 +1782,15 @@ def _import_hop_while_loop(
17811782 operands = carry_values ,
17821783 loc = loc ,
17831784 )
1785+
17841786 # Convert vtensor<bool> to torch.bool
17851787 bool_conv = Operation .create (
17861788 name = "torch.aten.Bool.Tensor" ,
17871789 results = [self ._cc .torch_bool_type ],
17881790 operands = [initial_cond_call .results [0 ]],
17891791 loc = loc ,
17901792 )
1793+
17911794 # Create max iterations constant (INT64_MAX)
17921795 with loc :
17931796 max_iter = _make_constant_op (
@@ -1811,6 +1814,7 @@ def _import_hop_while_loop(
18111814 block_arg_types = [self ._cc .torch_int_type ] + result_types
18121815 with loc :
18131816 loop_block = Block .create_at_start (loop_region , block_arg_types )
1817+
18141818 # Inside the loop body, call body function and condition function
18151819 with InsertionPoint (loop_block ):
18161820 # Call body function with current carry values (skip iteration counter)
@@ -1822,6 +1826,7 @@ def _import_hop_while_loop(
18221826 loc = loc ,
18231827 )
18241828 body_results = list (body_results_op .results )
1829+
18251830 # Call condition function with updated carries
18261831 cond_result_loop = Operation .create (
18271832 name = "func.call" ,
@@ -1830,20 +1835,23 @@ def _import_hop_while_loop(
18301835 operands = body_results ,
18311836 loc = loc ,
18321837 ).result
1838+
18331839 # Convert to bool
18341840 cond_bool = Operation .create (
18351841 name = "torch.aten.Bool.Tensor" ,
18361842 results = [self ._cc .torch_bool_type ],
18371843 operands = [cond_result_loop ],
18381844 loc = loc ,
18391845 ).result
1846+
18401847 # Emit loop condition with updated carries
18411848 Operation .create (
18421849 name = "torch.prim.Loop.condition" ,
18431850 results = [],
18441851 operands = [cond_bool ] + body_results ,
18451852 loc = loc ,
18461853 )
1854+
18471855 # Bind the loop results to the node
18481856 if len (result_types ) > 1 :
18491857 self ._multi_result_nodes .add (node )
0 commit comments