Skip to content

Commit 6178d07

Browse files
Fixed merge newline removals
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha@gmail.com>
1 parent dfdca75 commit 6178d07

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

python/torch_mlir/extras/fx_importer.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1771,6 +1771,7 @@ def _import_hop_while_loop(
17711771
self._multi_result_nodes.add(node)
17721772
else:
17731773
result_types = [self._cc.node_val_to_type(node)]
1774+
17741775
# Call the condition function with initial carries to get initial condition
17751776
cond_result_type = self._cc.get_vtensor_type(torch.Size([]), torch.bool)
17761777

@@ -1781,13 +1782,15 @@ def _import_hop_while_loop(
17811782
operands=carry_values,
17821783
loc=loc,
17831784
)
1785+
17841786
# Convert vtensor<bool> to torch.bool
17851787
bool_conv = Operation.create(
17861788
name="torch.aten.Bool.Tensor",
17871789
results=[self._cc.torch_bool_type],
17881790
operands=[initial_cond_call.results[0]],
17891791
loc=loc,
17901792
)
1793+
17911794
# Create max iterations constant (INT64_MAX)
17921795
with loc:
17931796
max_iter = _make_constant_op(
@@ -1811,6 +1814,7 @@ def _import_hop_while_loop(
18111814
block_arg_types = [self._cc.torch_int_type] + result_types
18121815
with loc:
18131816
loop_block = Block.create_at_start(loop_region, block_arg_types)
1817+
18141818
# Inside the loop body, call body function and condition function
18151819
with InsertionPoint(loop_block):
18161820
# Call body function with current carry values (skip iteration counter)
@@ -1822,6 +1826,7 @@ def _import_hop_while_loop(
18221826
loc=loc,
18231827
)
18241828
body_results = list(body_results_op.results)
1829+
18251830
# Call condition function with updated carries
18261831
cond_result_loop = Operation.create(
18271832
name="func.call",
@@ -1830,20 +1835,23 @@ def _import_hop_while_loop(
18301835
operands=body_results,
18311836
loc=loc,
18321837
).result
1838+
18331839
# Convert to bool
18341840
cond_bool = Operation.create(
18351841
name="torch.aten.Bool.Tensor",
18361842
results=[self._cc.torch_bool_type],
18371843
operands=[cond_result_loop],
18381844
loc=loc,
18391845
).result
1846+
18401847
# Emit loop condition with updated carries
18411848
Operation.create(
18421849
name="torch.prim.Loop.condition",
18431850
results=[],
18441851
operands=[cond_bool] + body_results,
18451852
loc=loc,
18461853
)
1854+
18471855
# Bind the loop results to the node
18481856
if len(result_types) > 1:
18491857
self._multi_result_nodes.add(node)

0 commit comments

Comments
 (0)