Skip to content

Commit 4cce607

Browse files
janselpytorchmergebot
authored andcommitted
Move TestIndexingSimplification to its own file (pytorch#97941)
test_torchinductor has gotten too big (almost 10k lines), this stack is trying to split it into smaller pieces. Pull Request resolved: pytorch#97941 Approved by: https://github.com/ngimel
1 parent 94bae36 commit 4cce607

File tree

2 files changed

+209
-211
lines changed

2 files changed

+209
-211
lines changed

test/inductor/test_indexing.py

Lines changed: 208 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
# Owner(s): ["module: inductor"]
2+
import sympy
3+
4+
from torch._inductor.codegen.cpp import cexpr
5+
from torch._inductor.codegen.triton import texpr
6+
from torch._inductor.codegen.wrapper import pexpr
7+
8+
from torch._inductor.ir import ModularIndexing
9+
from torch._inductor.sizevars import SizeVarAllocator
10+
from torch.fx.experimental.symbolic_shapes import FloorDiv
11+
from torch.testing._internal.common_utils import TestCase as TorchTestCase
12+
13+
14+
class TestIndexingSimplification(TorchTestCase):
15+
def test_indexing_simplification(self):
16+
sizevars = SizeVarAllocator()
17+
i0 = sympy.Symbol("i0", integer=True)
18+
i1 = sympy.Symbol("i1", integer=True)
19+
i2 = sympy.Symbol("i2", integer=True)
20+
r3 = sympy.Symbol("r3", integer=True)
21+
22+
var_ranges = {i0: 3136, i1: 64, i2: 32, r3: 3}
23+
expr = (
24+
128 * i2
25+
+ ModularIndexing(i1, 1, 64)
26+
+ 64 * ModularIndexing(i1 + 64 * r3, 64, 2)
27+
)
28+
# check that `i1//64` is removed when i1 is always less than 64,
29+
# and the next simplificaton doesn't happen
30+
self.assertEqual(
31+
sizevars.simplify_with_ranges(expr, var_ranges),
32+
i1 + 128 * i2 + 64 * ModularIndexing(r3, 1, 2),
33+
)
34+
# all the modular indexing should be removed when the body cant be larger than the modulus
35+
var_ranges[r3] = 2
36+
self.assertEqual(
37+
sizevars.simplify_with_ranges(expr, var_ranges), i1 + 128 * i2 + 64 * r3
38+
)
39+
# if there are negative terms in ModularIndexing base, we cannot replace it with FloorDiv
40+
expr = ModularIndexing(i1 - 15, 1, 64)
41+
self.assertEqual(
42+
sizevars.simplify_with_ranges(expr, var_ranges),
43+
ModularIndexing(i1 - 15, 1, 64),
44+
)
45+
# small terms should be kept if the rest is not guaranteed to be divisible
46+
self.assertEqual(
47+
sizevars.simplify_with_ranges(FloorDiv(r3 + i2 + i1, 32), var_ranges),
48+
FloorDiv(r3 + i2 + i1, 32),
49+
)
50+
51+
expr = ModularIndexing(2 * i2 + r3, 1, 64)
52+
# modular indexing is removed if base is smaller than modulo
53+
self.assertEqual(sizevars.simplify_with_ranges(expr, var_ranges), 2 * i2 + r3)
54+
55+
# check the same thing but with symbolic divisor
56+
self.assertEqual(FloorDiv(r3 * i0, r3), i0)
57+
self.assertEqual(ModularIndexing(r3 * i0, r3, 10), ModularIndexing(i0, 1, 10))
58+
59+
# (10*i) % 10 is always zero and should get optimized away
60+
self.assertEqual(
61+
ModularIndexing(i0 + i1 * 10, 1, 10), ModularIndexing(i0, 1, 10)
62+
)
63+
64+
# ((20*i)//2) % 10 is always zero and should get optimized away
65+
self.assertEqual(
66+
ModularIndexing(i0 + i1 * 20, 2, 10), ModularIndexing(i0, 2, 10)
67+
)
68+
69+
# the same things happens with symbolic divisor
70+
self.assertEqual(
71+
ModularIndexing(i0 + i1 * i2 * r3, i2, r3), ModularIndexing(i0, i2, r3)
72+
)
73+
74+
# if there are negative terms, we cannot optimize away zero terms due to https://github.com/openai/triton/issues/619
75+
self.assertEqual(
76+
ModularIndexing(-i0 + i1 * 20, 2, 10), ModularIndexing(-i0 + i1 * 20, 2, 10)
77+
)
78+
self.assertEqual(
79+
ModularIndexing(-15 + i1 * 20, 2, 10), ModularIndexing(-15 + i1 * 20, 2, 10)
80+
)
81+
82+
# Constant fold from divisor into base
83+
self.assertEqual(ModularIndexing(i0 * 4, 2, 10), ModularIndexing(i0 * 2, 1, 10))
84+
self.assertEqual(FloorDiv(i0 * 4, 2), i0 * 2)
85+
86+
# Nested modular indexing is correctly simplified
87+
var_ranges = {"i1": 13, "i2": 121}
88+
expr = ModularIndexing(ModularIndexing(121 * i1 + i2, 1, 784), 1, 28)
89+
self.assertEqual(sizevars.simplify_with_ranges(expr, var_ranges), expr)
90+
expr = ModularIndexing(ModularIndexing(121 * i1 + i2, 1, 784) + 1, 1, 28)
91+
self.assertEqual(sizevars.simplify_with_ranges(expr, var_ranges), expr)
92+
var_ranges = {"i2": 784}
93+
expr = ModularIndexing(ModularIndexing(i2, 1, 28), 7, 4)
94+
expected = FloorDiv(ModularIndexing(i2, 1, 28), 7)
95+
self.assertEqual(sizevars.simplify_with_ranges(expr, var_ranges), expected)
96+
expr = ModularIndexing(ModularIndexing(i2, 1, 28) + 1, 7, 4)
97+
self.assertEqual(sizevars.simplify_with_ranges(expr, var_ranges), expr)
98+
99+
def test_indexing_join(self):
100+
sizevars = SizeVarAllocator()
101+
i0 = sympy.Symbol("i0", integer=True)
102+
i1 = sympy.Symbol("i1", integer=True)
103+
i2 = sympy.Symbol("i2", integer=True)
104+
105+
# join two ModularIndexing calls into one larger one when possible
106+
expr1 = ModularIndexing(i0, 1, 32) + 32 * ModularIndexing(i0, 32, 4)
107+
self.assertEqual(
108+
sizevars.simplify_with_ranges(expr1, {}), ModularIndexing(i0, 1, 128)
109+
)
110+
111+
# it should also work with a scale
112+
self.assertEqual(
113+
sizevars.simplify_with_ranges(2 * expr1, {}),
114+
2 * ModularIndexing(i0, 1, 128),
115+
)
116+
117+
# it should work when divisor is not 1
118+
expr2 = ModularIndexing(i0, 3, 32) + 32 * ModularIndexing(i0, 32 * 3, 4)
119+
simplified = sizevars.simplify_with_ranges(expr2, {})
120+
self.assertEqual(simplified, ModularIndexing(i0, 3, 128))
121+
self.assertEqual(expr2.subs({i0: 39485}), simplified.subs({i0: 39485}))
122+
123+
# it should not happen in this case as the modulus is wrong
124+
expr3 = ModularIndexing(i0, 1, 30) + 32 * ModularIndexing(i0, 32, 4)
125+
self.assertEqual(sizevars.simplify_with_ranges(expr3, {}), expr3)
126+
127+
# check that it also works with a modulus>1
128+
expr4 = ModularIndexing(i0, 10, i1) + i1 * ModularIndexing(i0, i1 * 10, i2)
129+
res0 = expr4.subs({i0: 24056, i1: 13, i2: 19})
130+
simplified = sizevars.simplify_with_ranges(expr4, {})
131+
res1 = simplified.subs({i0: 24056, i1: 13, i2: 19})
132+
self.assertEqual(res0, res1)
133+
self.assertEqual(simplified, ModularIndexing(i0, 10, i1 * i2))
134+
135+
# and also works with an offset
136+
self.assertEqual(
137+
sizevars.simplify_with_ranges(expr4 + 10, {}),
138+
ModularIndexing(i0, 10, i1 * i2) + 10,
139+
)
140+
141+
# works for ModularIndexing + FloorDiv
142+
expr5 = 197 * FloorDiv(i0, 197) + ModularIndexing(i0, 1, 197)
143+
simplified = sizevars.simplify_with_ranges(expr5, {})
144+
self.assertEqual(simplified, i0)
145+
self.assertEqual(expr5.subs({i0: 39485}), simplified.subs({i0: 39485}))
146+
147+
# works with a scale
148+
self.assertEqual(
149+
sizevars.simplify_with_ranges(2 * expr5, {}),
150+
2 * i0,
151+
)
152+
153+
# divisor != 1
154+
expr6 = 197 * FloorDiv(i0, 197 * 3) + ModularIndexing(i0, 3, 197)
155+
simplified = sizevars.simplify_with_ranges(expr6, {})
156+
self.assertEqual(simplified, FloorDiv(i0, 3))
157+
self.assertEqual(expr6.subs({i0: 39485}), simplified.subs({i0: 39485}))
158+
159+
160+
class ExprPrinterTests(TorchTestCase):
161+
def test_print_pow(self):
162+
s1 = sympy.Symbol("foo", integer=True)
163+
s2 = sympy.Symbol("bar", integer=True)
164+
s3 = sympy.Symbol("baz", integer=True)
165+
166+
cases = (
167+
# expr, result
168+
# Test exprs.
169+
(
170+
s1 / (2 * s1 - 1) - 1 / (2 * s1 - 1),
171+
lambda c: f"((-1)*({c}/((-1) + (2*foo)))) + (foo*({c}/((-1) + (2*foo))))",
172+
),
173+
(s1 / (s2 - s3), lambda c: f"foo*({c}/(bar + ((-1)*baz)))"),
174+
# Test Pow directly.
175+
(
176+
sympy.Pow(s1 + s2, 0),
177+
lambda _: "1",
178+
), # note: simplified before _print_Pow
179+
(
180+
sympy.Pow(s1 + s2, -3),
181+
lambda c: f"{c}/((bar + foo)*(bar + foo)*(bar + foo))",
182+
),
183+
(sympy.Pow(s1 + s2, 2), lambda _: "(bar + foo)*(bar + foo)"),
184+
)
185+
186+
for expr, result in cases:
187+
self.assertEqual(cexpr(expr), result(1.0)) # 1.0 for FP div
188+
self.assertEqual(texpr(expr), result(1))
189+
self.assertEqual(pexpr(expr), result(1))
190+
191+
def test_print_floor(self):
192+
s1 = sympy.Symbol("s1", integer=False)
193+
expr = sympy.floor(s1)
194+
self.assertEqual(texpr(expr), "tl.math.floor(s1)")
195+
self.assertEqual(pexpr(expr), "math.floor(s1)")
196+
197+
def test_print_ceil(self):
198+
s1 = sympy.Symbol("s1", integer=False)
199+
expr = sympy.ceiling(s1)
200+
self.assertEqual(pexpr(expr), "math.ceil(s1)")
201+
202+
203+
if __name__ == "__main__":
204+
from torch._dynamo.test_case import run_tests
205+
from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA
206+
207+
if HAS_CPU or HAS_CUDA:
208+
run_tests("sympy")

0 commit comments

Comments
 (0)