Skip to content

Commit 94bae36

Browse files
lantiankaikaipytorchmergebot
authored andcommitted
Fix strip_function_call in GuardBuilder (pytorch#97810)
repo: from pytorch#92670 this address one of the bug for TorchDynamo pytest ./generated/test_PeterouZh_CIPS_3D.py -k test_003 Issue: In GuardBuilder, when parsing argnames with "getattr(a.layers[slice(2)][0]._abc, '0')" it returns "getattr(a", where it suppose to return "a", and thus causing SyntaxError. This PR fix the regex and add couple test cases. Fixes #ISSUE_NUMBER Pull Request resolved: pytorch#97810 Approved by: https://github.com/yanboliang
1 parent ffd76d1 commit 94bae36

File tree

2 files changed

+31
-3
lines changed

2 files changed

+31
-3
lines changed

test/dynamo/test_misc.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5231,6 +5231,21 @@ def forward(self, input):
52315231
prof.report()
52325232
)
52335233

5234+
def test_guards_strip_function_call(self):
5235+
from torch._dynamo.guards import strip_function_call
5236+
5237+
test_case = [
5238+
("___odict_getitem(a, 1)", "a"),
5239+
("a.layers[slice(2)][0]._xyz", "a"),
5240+
("getattr(a.layers[slice(2)][0]._abc, '0')", "a"),
5241+
("getattr(getattr(a.x[3], '0'), '3')", "a"),
5242+
("a.layers[slice(None, -1, None)][0]._xyz", "a"),
5243+
("a.layers[func('offset', -1, None)][0]._xyz", "a"),
5244+
]
5245+
# strip_function_call should extract the object from the string.
5246+
for name, expect_obj in test_case:
5247+
self.assertEqual(strip_function_call(name), expect_obj)
5248+
52345249

52355250
class CustomFunc1(torch.autograd.Function):
52365251
@staticmethod

torch/_dynamo/guards.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,10 +67,23 @@
6767
def strip_function_call(name):
6868
"""
6969
"___odict_getitem(a, 1)" => "a"
70+
"a.layers[slice(2)][0]._xyz" ==> "a"
71+
"getattr(a.layers[slice(2)][0]._abc, '0')" ==> "a"
72+
"getattr(getattr(a.x[3], '0'), '3')" ==> "a"
73+
"a.layers[slice(None, -1, None)][0]._xyz" ==> "a"
7074
"""
71-
m = re.search(r"([a-z0-9_]+)\(([^(),]+)[^()]*\)", name)
72-
if m and m.group(1) != "slice":
73-
return strip_function_call(m.group(2))
75+
# recursively find valid object name in fuction
76+
valid_name = re.compile("[A-Za-z_].*")
77+
curr = ""
78+
for char in name:
79+
if char in " (":
80+
curr = ""
81+
elif char in "),[]":
82+
if curr and curr != "None" and valid_name.match(curr):
83+
return strip_function_call(curr)
84+
else:
85+
curr += char
86+
7487
return strip_getattr_getitem(name)
7588

7689

0 commit comments

Comments
 (0)