Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/kirin/dialects/vmath/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from . import stmts as stmts, interp as interp
from ._dialect import dialect as dialect
from .rewrites import desugar as desugar

pi = pymath.pi
e = pymath.e
Expand Down
16 changes: 0 additions & 16 deletions src/kirin/dialects/vmath/passes.py

This file was deleted.

2 changes: 2 additions & 0 deletions src/kirin/dialects/vmath/rewrites/desugar.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from kirin.dialects.ilist import IListType

from ..stmts import add as vadd, div as vdiv, sub as vsub, mult as vmult
from .._dialect import dialect


class DesugarBinOp(RewriteRule):
Expand Down Expand Up @@ -52,6 +53,7 @@ def replace_binop(self, node: ir.Statement) -> RewriteResult:
return RewriteResult()


@dialect.post_inference
class WalkDesugarBinop(RewriteRule):
"""
Walks DesugarBinop. Needed for correct behavior when
Expand Down
15 changes: 4 additions & 11 deletions test/dialects/vmath/test_desugar.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

from kirin.prelude import basic
from kirin.dialects import vmath
from kirin.dialects.vmath.passes import VMathDesugar
from kirin.dialects.ilist.runtime import IList


Expand All @@ -25,7 +24,6 @@ def add_scalar_lhs():

def test_add_scalar_lhs():
# out = add_scalar_lhs()
VMathDesugar(add_scalar_lhs.dialects).unsafe_run(add_scalar_lhs)
add_scalar_lhs.print()
res = add_scalar_lhs()
assert isinstance(res, IList)
Expand All @@ -34,51 +32,46 @@ def test_add_scalar_lhs():


def test_typed_kernel_add():
VMathDesugar(add_scalar_rhs_typed.dialects).unsafe_run(add_scalar_rhs_typed)
add_scalar_rhs_typed.print()
res = add_scalar_rhs_typed(IList([0, 1, 2]), 3.1)
assert np.allclose(np.asarray(res), np.asarray([3.1, 4.1, 5.1]))


@basic.union([vmath])
@basic.union([vmath])(typeinfer=True)
def add_two_lists():
return add_kernel(x=[0, 1, 2], y=[3, 4, 5])


def test_add_lists():
VMathDesugar(add_two_lists.dialects).unsafe_run(add_two_lists)
res = add_two_lists()
assert np.allclose(np.asarray(res), np.array([0, 1, 2, 3, 4, 5]))


@basic.union([vmath])
@basic.union([vmath])(typeinfer=True)
def sub_scalar_rhs_typed(x: IList[float, Any], y: float):
return x - y


def test_sub_scalar_typed():
VMathDesugar(sub_scalar_rhs_typed.dialects).unsafe_run(sub_scalar_rhs_typed)
res = sub_scalar_rhs_typed(IList([0, 1, 2]), 3.1)
assert np.allclose(np.asarray(res), np.asarray([-3.1, -2.1, -1.1]))


@basic.union([vmath])
@basic.union([vmath])(typeinfer=True)
def mult_scalar_lhs_typed(x: float, y: IList[float, Any]):
return x * y


def test_mult_scalar_typed():
VMathDesugar(mult_scalar_lhs_typed.dialects).unsafe_run(mult_scalar_lhs_typed)
res = mult_scalar_lhs_typed(3, IList([0, 1, 2]))
assert np.allclose(np.asarray(res), np.asarray([0, 3, 6]))


@basic.union([vmath])
@basic.union([vmath])(typeinfer=True)
def div_scalar_lhs_typed(x: float, y: IList[float, Any]):
return x / y


def test_div_scalar_typed():
VMathDesugar(div_scalar_lhs_typed.dialects).unsafe_run(div_scalar_lhs_typed)
res = div_scalar_lhs_typed(3, IList([1, 1.5, 2]))
assert np.allclose(np.asarray(res), np.asarray([3, 2, 1.5]))