Skip to content

Commit 2aa3a06

Browse files
laithsakka authored and Silv3S committed
Always track _local_scalar_dense output in tensorify_python_scalars. (pytorch#166573)
We need to track all symbols, we used to skip u = item() and fail with ``` File "/home/lsakka/pytorch10/pytorch/torch/fx/passes/_tensorify_python_scalars.py", line 149, in _sympy_interp expr_to_sym_proxy[expr] torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised: KeyError: u0 ``` Pull Request resolved: pytorch#166573 Approved by: https://github.com/bobrenjc93
1 parent 66064dd commit 2aa3a06

File tree

2 files changed

+48
-5
lines changed

2 files changed

+48
-5
lines changed

test/dynamo/test_misc.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14036,6 +14036,44 @@ def fuzzed_program(arg_0, sentinel):
1403614036
except Exception as e:
1403714037
self.fail(f"torch.compile failed with error: {e}")
1403814038

14039+
@torch._dynamo.config.patch(capture_scalar_outputs=True)
14040+
def test_tensorify_track_item_symint(self):
14041+
def _random_resize(image: torch.Tensor):
14042+
image_metanet = image
14043+
default_patch_size = 14
14044+
rand_cnn_resolution = (224, 256)
14045+
min_nump = rand_cnn_resolution[0] // default_patch_size
14046+
max_nump = rand_cnn_resolution[1] // default_patch_size
14047+
new_nump = torch.randint(min_nump, max_nump + 1, (1,)).item()
14048+
torch._check(new_nump > 0)
14049+
torch._check(new_nump * default_patch_size > 1)
14050+
14051+
image_metanet = F.interpolate(
14052+
image_metanet,
14053+
size=(new_nump * default_patch_size, new_nump * default_patch_size),
14054+
mode="bilinear",
14055+
align_corners=True,
14056+
)
14057+
img_h_new, img_w_new = image_metanet.shape[2:]
14058+
14059+
return (img_h_new, img_w_new), image_metanet
14060+
14061+
_random_resize_compiled = torch.compile(fullgraph=True)(_random_resize)
14062+
14063+
# Test the function
14064+
input_tensor = torch.rand(1, 3, 224, 224)
14065+
(h, w), output = _random_resize_compiled(input_tensor)
14066+
14067+
# Verify output properties
14068+
self.assertEqual(output.shape[0], 1)
14069+
self.assertEqual(output.shape[1], 3)
14070+
self.assertEqual(output.shape[2], h)
14071+
self.assertEqual(output.shape[3], w)
14072+
self.assertTrue(h % 14 == 0)
14073+
self.assertTrue(w % 14 == 0)
14074+
self.assertTrue(224 <= h <= 256)
14075+
self.assertTrue(224 <= w <= 256)
14076+
1403914077

1404014078
if __name__ == "__main__":
1404114079
from torch._dynamo.test_case import run_tests

torch/fx/passes/_tensorify_python_scalars.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -207,22 +207,27 @@ def _sympy_interp(expr: sympy.Expr) -> MetaProxy:
207207
and node.target is torch.ops.aten._local_scalar_dense.default
208208
):
209209
dtype = node.args[0].meta["val"].dtype
210-
if not dtype.is_floating_point:
211-
continue
212210

213211
assert isinstance(node.args[0], fx.Node), node.args[0]
214212

215213
s = node.meta["val"].node.expr
214+
215+
expr_to_sym_proxy[s] = MetaProxy(
216+
node, tracer=tracer, fake_mode=fake_mode
217+
)
218+
219+
# only tensorify if the dtype is floating point
220+
if not dtype.is_floating_point:
221+
continue
222+
216223
expr_to_tensor_proxy[s] = MetaProxy(
217224
node.args[0], tracer=tracer, fake_mode=fake_mode
218225
)
219226
# Upcast the float tensor to torch.float64 to avoid precision problem
220227
expr_to_tensor_proxy[s] = torch.ops.prims.convert_element_type.default(
221228
expr_to_tensor_proxy[s], torch.float64
222229
)
223-
expr_to_sym_proxy[s] = MetaProxy(
224-
node, tracer=tracer, fake_mode=fake_mode
225-
)
230+
226231
# pyrefly: ignore [bad-argument-type]
227232
elif (sym_expr := _get_sym_val(node)) is not None:
228233
if sym_expr not in expr_to_sym_proxy and not isinstance(

0 commit comments

Comments (0)