|
| 1 | +# Owner(s): ["module: inductor"] |
| 2 | +import sympy |
| 3 | + |
| 4 | +from torch._inductor.codegen.cpp import cexpr |
| 5 | +from torch._inductor.codegen.triton import texpr |
| 6 | +from torch._inductor.codegen.wrapper import pexpr |
| 7 | + |
| 8 | +from torch._inductor.ir import ModularIndexing |
| 9 | +from torch._inductor.sizevars import SizeVarAllocator |
| 10 | +from torch.fx.experimental.symbolic_shapes import FloorDiv |
| 11 | +from torch.testing._internal.common_utils import TestCase as TorchTestCase |
| 12 | + |
| 13 | + |
class TestIndexingSimplification(TorchTestCase):
    """Tests for simplification of ``ModularIndexing`` / ``FloorDiv`` index
    expressions, both via ``SizeVarAllocator.simplify_with_ranges`` and via
    identities asserted directly on freshly constructed expressions."""

    def test_indexing_simplification(self):
        """Exercise range-aware simplification of modular and floor-div indexing."""
        allocator = SizeVarAllocator()
        simplify = allocator.simplify_with_ranges
        i0, i1, i2, r3 = (
            sympy.Symbol(name, integer=True) for name in ("i0", "i1", "i2", "r3")
        )

        ranges = {i0: 3136, i1: 64, i2: 32, r3: 3}
        index = (
            128 * i2
            + ModularIndexing(i1, 1, 64)
            + 64 * ModularIndexing(i1 + 64 * r3, 64, 2)
        )
        # `i1//64` must be removed because i1 is always less than 64, while
        # the follow-up simplification (dropping the ModularIndexing over r3)
        # must not happen yet.
        self.assertEqual(
            simplify(index, ranges),
            i1 + 128 * i2 + 64 * ModularIndexing(r3, 1, 2),
        )

        # Once the body can't be larger than the modulus, all the modular
        # indexing should be removed.
        ranges[r3] = 2
        self.assertEqual(simplify(index, ranges), i1 + 128 * i2 + 64 * r3)

        # A negative term in the ModularIndexing base means it cannot be
        # replaced with FloorDiv.
        shifted = ModularIndexing(i1 - 15, 1, 64)
        self.assertEqual(
            simplify(shifted, ranges),
            ModularIndexing(i1 - 15, 1, 64),
        )

        # Small terms should be kept if the rest is not guaranteed to be divisible.
        self.assertEqual(
            simplify(FloorDiv(r3 + i2 + i1, 32), ranges),
            FloorDiv(r3 + i2 + i1, 32),
        )

        # Modular indexing is removed if the base is smaller than the modulus.
        bounded = ModularIndexing(2 * i2 + r3, 1, 64)
        self.assertEqual(simplify(bounded, ranges), 2 * i2 + r3)

        # The pairs below are compared directly, without going through
        # simplify_with_ranges: (constructed expression, expected expression).
        construction_cases = (
            # the same kind of cancellation, but with a symbolic divisor
            (FloorDiv(r3 * i0, r3), i0),
            (ModularIndexing(r3 * i0, r3, 10), ModularIndexing(i0, 1, 10)),
            # (10*i) % 10 is always zero and should get optimized away
            (ModularIndexing(i0 + i1 * 10, 1, 10), ModularIndexing(i0, 1, 10)),
            # ((20*i)//2) % 10 is always zero and should get optimized away
            (ModularIndexing(i0 + i1 * 20, 2, 10), ModularIndexing(i0, 2, 10)),
            # the same thing happens with a symbolic divisor
            (
                ModularIndexing(i0 + i1 * i2 * r3, i2, r3),
                ModularIndexing(i0, i2, r3),
            ),
            # with negative terms, zero terms cannot be optimized away due to
            # https://github.com/openai/triton/issues/619
            (
                ModularIndexing(-i0 + i1 * 20, 2, 10),
                ModularIndexing(-i0 + i1 * 20, 2, 10),
            ),
            (
                ModularIndexing(-15 + i1 * 20, 2, 10),
                ModularIndexing(-15 + i1 * 20, 2, 10),
            ),
            # constant fold from divisor into base
            (ModularIndexing(i0 * 4, 2, 10), ModularIndexing(i0 * 2, 1, 10)),
            (FloorDiv(i0 * 4, 2), i0 * 2),
        )
        for built, wanted in construction_cases:
            self.assertEqual(built, wanted)

        # Nested modular indexing is correctly simplified.
        # NOTE(review): the range dicts below are keyed by strings, unlike the
        # Symbol-keyed dict used above -- confirm simplify_with_ranges treats
        # both kinds of key as intended.
        ranges = {"i1": 13, "i2": 121}
        nested = ModularIndexing(ModularIndexing(121 * i1 + i2, 1, 784), 1, 28)
        self.assertEqual(simplify(nested, ranges), nested)
        nested = ModularIndexing(ModularIndexing(121 * i1 + i2, 1, 784) + 1, 1, 28)
        self.assertEqual(simplify(nested, ranges), nested)

        ranges = {"i2": 784}
        nested = ModularIndexing(ModularIndexing(i2, 1, 28), 7, 4)
        wanted = FloorDiv(ModularIndexing(i2, 1, 28), 7)
        self.assertEqual(simplify(nested, ranges), wanted)
        nested = ModularIndexing(ModularIndexing(i2, 1, 28) + 1, 7, 4)
        self.assertEqual(simplify(nested, ranges), nested)

    def test_indexing_join(self):
        """Exercise merging of adjacent ModularIndexing / FloorDiv terms."""
        allocator = SizeVarAllocator()
        simplify = allocator.simplify_with_ranges
        i0, i1, i2 = (sympy.Symbol(name, integer=True) for name in ("i0", "i1", "i2"))
        # Concrete substitutions used to compare an expression numerically
        # against its simplified form.
        probe = {i0: 39485}
        symbolic_probe = {i0: 24056, i1: 13, i2: 19}

        # Two ModularIndexing calls are joined into one larger one when possible.
        pair = ModularIndexing(i0, 1, 32) + 32 * ModularIndexing(i0, 32, 4)
        self.assertEqual(simplify(pair, {}), ModularIndexing(i0, 1, 128))

        # It should also work with a scale.
        self.assertEqual(
            simplify(2 * pair, {}),
            2 * ModularIndexing(i0, 1, 128),
        )

        # It should work when the divisor is not 1.
        strided = ModularIndexing(i0, 3, 32) + 32 * ModularIndexing(i0, 32 * 3, 4)
        joined = simplify(strided, {})
        self.assertEqual(joined, ModularIndexing(i0, 3, 128))
        self.assertEqual(strided.subs(probe), joined.subs(probe))

        # It should not happen in this case as the modulus is wrong.
        mismatched = ModularIndexing(i0, 1, 30) + 32 * ModularIndexing(i0, 32, 4)
        self.assertEqual(simplify(mismatched, {}), mismatched)

        # It also works with symbolic moduli (and a divisor of 10).
        symbolic = ModularIndexing(i0, 10, i1) + i1 * ModularIndexing(
            i0, i1 * 10, i2
        )
        value_before = symbolic.subs(symbolic_probe)
        joined = simplify(symbolic, {})
        value_after = joined.subs(symbolic_probe)
        self.assertEqual(value_before, value_after)
        self.assertEqual(joined, ModularIndexing(i0, 10, i1 * i2))

        # ... and also works with an offset.
        self.assertEqual(
            simplify(symbolic + 10, {}),
            ModularIndexing(i0, 10, i1 * i2) + 10,
        )

        # Works for ModularIndexing + FloorDiv.
        split = 197 * FloorDiv(i0, 197) + ModularIndexing(i0, 1, 197)
        joined = simplify(split, {})
        self.assertEqual(joined, i0)
        self.assertEqual(split.subs(probe), joined.subs(probe))

        # Works with a scale.
        self.assertEqual(simplify(2 * split, {}), 2 * i0)

        # divisor != 1
        split_strided = 197 * FloorDiv(i0, 197 * 3) + ModularIndexing(i0, 3, 197)
        joined = simplify(split_strided, {})
        self.assertEqual(joined, FloorDiv(i0, 3))
        self.assertEqual(split_strided.subs(probe), joined.subs(probe))
| 158 | + |
| 159 | + |
class ExprPrinterTests(TorchTestCase):
    """Tests for the strings produced by the ``cexpr`` (cpp codegen),
    ``texpr`` (triton codegen) and ``pexpr`` (wrapper codegen) printers."""

    def test_print_pow(self):
        """Powers and divisions print as explicit products / quotients."""
        foo = sympy.Symbol("foo", integer=True)
        bar = sympy.Symbol("bar", integer=True)
        baz = sympy.Symbol("baz", integer=True)

        # (expression, expected-output template); ``{c}`` is the numerator
        # constant, which is filled in per printer below.
        cases = (
            # Test exprs.
            (
                foo / (2 * foo - 1) - 1 / (2 * foo - 1),
                "((-1)*({c}/((-1) + (2*foo)))) + (foo*({c}/((-1) + (2*foo))))",
            ),
            (foo / (bar - baz), "foo*({c}/(bar + ((-1)*baz)))"),
            # Test Pow directly.
            # note: this one is simplified before _print_Pow
            (sympy.Pow(foo + bar, 0), "1"),
            (
                sympy.Pow(foo + bar, -3),
                "{c}/((bar + foo)*(bar + foo)*(bar + foo))",
            ),
            (sympy.Pow(foo + bar, 2), "(bar + foo)*(bar + foo)"),
        )

        for expr, template in cases:
            # The cpp printer is expected to use 1.0 so the division is FP.
            self.assertEqual(cexpr(expr), template.format(c=1.0))
            for printer in (texpr, pexpr):
                self.assertEqual(printer(expr), template.format(c=1))

    def test_print_floor(self):
        """``floor`` of a non-integer symbol prints as a floor call."""
        value = sympy.Symbol("s1", integer=False)
        floored = sympy.floor(value)
        self.assertEqual(texpr(floored), "tl.math.floor(s1)")
        self.assertEqual(pexpr(floored), "math.floor(s1)")

    def test_print_ceil(self):
        """``ceiling`` of a non-integer symbol prints as ``math.ceil`` in pexpr."""
        value = sympy.Symbol("s1", integer=False)
        ceiled = sympy.ceiling(value)
        self.assertEqual(pexpr(ceiled), "math.ceil(s1)")
| 201 | + |
| 202 | + |
if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests
    from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA

    # Only launch the suite when at least one of the backend flags is set.
    backend_available = HAS_CPU or HAS_CUDA
    if backend_available:
        # NOTE(review): "sympy" is passed straight to run_tests; presumably it
        # names a module the tests require -- confirm against
        # torch._dynamo.test_case.run_tests.
        run_tests("sympy")
0 commit comments