Skip to content

Commit d133226

Browse files
authored
register vmath desugar as post inference and remove pass (#573)
- vmath desugar rewrite (#568) is now registered as a vmath dialect post-inference rewrite rule - remove vmath/passes.py as it is no longer needed These changes are separate from (#568) since they are not compatible with kirin 0.17 (in case we want to backport).
1 parent 7982391 commit d133226

File tree

4 files changed

+7
-27
lines changed

4 files changed

+7
-27
lines changed

src/kirin/dialects/vmath/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from . import stmts as stmts, interp as interp
88
from ._dialect import dialect as dialect
9+
from .rewrites import desugar as desugar
910

1011
pi = pymath.pi
1112
e = pymath.e

src/kirin/dialects/vmath/passes.py

Lines changed: 0 additions & 16 deletions
This file was deleted.

src/kirin/dialects/vmath/rewrites/desugar.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from kirin.dialects.ilist import IListType
77

88
from ..stmts import add as vadd, div as vdiv, sub as vsub, mult as vmult
9+
from .._dialect import dialect
910

1011

1112
class DesugarBinOp(RewriteRule):
@@ -52,6 +53,7 @@ def replace_binop(self, node: ir.Statement) -> RewriteResult:
5253
return RewriteResult()
5354

5455

56+
@dialect.post_inference
5557
class WalkDesugarBinop(RewriteRule):
5658
"""
5759
Walks DesugarBinop. Needed for correct behavior when

test/dialects/vmath/test_desugar.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44

55
from kirin.prelude import basic
66
from kirin.dialects import vmath
7-
from kirin.dialects.vmath.passes import VMathDesugar
87
from kirin.dialects.ilist.runtime import IList
98

109

@@ -25,7 +24,6 @@ def add_scalar_lhs():
2524

2625
def test_add_scalar_lhs():
2726
# out = add_scalar_lhs()
28-
VMathDesugar(add_scalar_lhs.dialects).unsafe_run(add_scalar_lhs)
2927
add_scalar_lhs.print()
3028
res = add_scalar_lhs()
3129
assert isinstance(res, IList)
@@ -34,51 +32,46 @@ def test_add_scalar_lhs():
3432

3533

3634
def test_typed_kernel_add():
37-
VMathDesugar(add_scalar_rhs_typed.dialects).unsafe_run(add_scalar_rhs_typed)
3835
add_scalar_rhs_typed.print()
3936
res = add_scalar_rhs_typed(IList([0, 1, 2]), 3.1)
4037
assert np.allclose(np.asarray(res), np.asarray([3.1, 4.1, 5.1]))
4138

4239

43-
@basic.union([vmath])
40+
@basic.union([vmath])(typeinfer=True)
4441
def add_two_lists():
4542
return add_kernel(x=[0, 1, 2], y=[3, 4, 5])
4643

4744

4845
def test_add_lists():
49-
VMathDesugar(add_two_lists.dialects).unsafe_run(add_two_lists)
5046
res = add_two_lists()
5147
assert np.allclose(np.asarray(res), np.array([0, 1, 2, 3, 4, 5]))
5248

5349

54-
@basic.union([vmath])
50+
@basic.union([vmath])(typeinfer=True)
5551
def sub_scalar_rhs_typed(x: IList[float, Any], y: float):
5652
return x - y
5753

5854

5955
def test_sub_scalar_typed():
60-
VMathDesugar(sub_scalar_rhs_typed.dialects).unsafe_run(sub_scalar_rhs_typed)
6156
res = sub_scalar_rhs_typed(IList([0, 1, 2]), 3.1)
6257
assert np.allclose(np.asarray(res), np.asarray([-3.1, -2.1, -1.1]))
6358

6459

65-
@basic.union([vmath])
60+
@basic.union([vmath])(typeinfer=True)
6661
def mult_scalar_lhs_typed(x: float, y: IList[float, Any]):
6762
return x * y
6863

6964

7065
def test_mult_scalar_typed():
71-
VMathDesugar(mult_scalar_lhs_typed.dialects).unsafe_run(mult_scalar_lhs_typed)
7266
res = mult_scalar_lhs_typed(3, IList([0, 1, 2]))
7367
assert np.allclose(np.asarray(res), np.asarray([0, 3, 6]))
7468

7569

76-
@basic.union([vmath])
70+
@basic.union([vmath])(typeinfer=True)
7771
def div_scalar_lhs_typed(x: float, y: IList[float, Any]):
7872
return x / y
7973

8074

8175
def test_div_scalar_typed():
82-
VMathDesugar(div_scalar_lhs_typed.dialects).unsafe_run(div_scalar_lhs_typed)
8376
res = div_scalar_lhs_typed(3, IList([1, 1.5, 2]))
8477
assert np.allclose(np.asarray(res), np.asarray([3, 2, 1.5]))

0 commit comments

Comments
 (0)