Skip to content

Commit 5c9b3e6

Browse files
New Feature: Adds support for composed kernel reordering for Mod Switch (#101)
add support for composed kernel reordering for Mod.
1 parent b38497e commit 5c9b3e6

File tree

3 files changed

+63
-17
lines changed

3 files changed

+63
-17
lines changed

p-isa_tools/kerngen/kernel_optimization/loops.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,54 @@
1010
from high_parser.pisa_operations import PIsaOp, Comment
1111

1212

13+
def remove_comments(pisa_list: list[PIsaOp]) -> list[PIsaOp]:
14+
"""Remove comments from a list of PIsaOp instructions.
15+
16+
Args:
17+
pisa_list: List of PIsaOp instructions
18+
19+
Returns:
20+
List of PIsaOp instructions without comments
21+
"""
22+
return [pisa for pisa in pisa_list if not isinstance(pisa, Comment)]
23+
24+
25+
def split_by_reorderable(pisa_list: list[PIsaOp]) -> tuple[list[PIsaOp], list[PIsaOp]]:
26+
"""Split a list of PIsaOp instructions into reorderable and non-reorderable groups.
27+
28+
Args:
29+
pisa_list: List of PIsaOp instructions
30+
31+
Returns:
32+
Tuple containing two lists:
33+
- reorderable: Instructions that can be reordered
34+
- non_reorderable: Instructions that cannot be reordered
35+
"""
36+
37+
reorderable = []
38+
non_reorderable = []
39+
is_reorderable = False
40+
41+
for pisa in pisa_list:
42+
# if the pisa is a comment and it contains <reorderable> tag, treat the following pisa as reorderable until a </reorderable> tag is found.
43+
if isinstance(pisa, Comment):
44+
if "<reorderable>" in pisa.line:
45+
is_reorderable = True
46+
elif "</reorderable>" in pisa.line:
47+
is_reorderable = False
48+
49+
if is_reorderable:
50+
reorderable.append(pisa)
51+
else:
52+
non_reorderable.append(pisa)
53+
54+
# if reoderable is empty, return non_reorderable as reorderable
55+
if not reorderable:
56+
reorderable = non_reorderable
57+
non_reorderable = []
58+
return remove_comments(reorderable), remove_comments(non_reorderable)
59+
60+
1361
def loop_interchange(
1462
pisa_list: list[PIsaOp],
1563
primary_key: LoopKey | None = LoopKey.PART,
@@ -52,7 +100,7 @@ def get_sort_key(pisa: PIsaOp) -> tuple:
52100
return (primary_value,)
53101

54102
# Filter out comments
55-
pisa_list_wo_comments = [p for p in pisa_list if not isinstance(p, Comment)]
103+
pisa_list_wo_comments = remove_comments(pisa_list)
56104
# Sort based on primary and optional secondary keys
57105
pisa_list_wo_comments.sort(key=get_sort_key)
58106
return pisa_list_wo_comments

p-isa_tools/kerngen/kerngraph.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
import argparse
3434
import sys
3535
from kernel_parser.parser import KernelParser
36-
from kernel_optimization.loops import loop_interchange
36+
from kernel_optimization.loops import loop_interchange, split_by_reorderable
3737
from const.options import LoopKey
3838
from pisa_generators.basic import mixed_to_pisa_ops
3939
from high_parser.config import Config
@@ -52,14 +52,7 @@ def parse_args():
5252
nargs="*",
5353
default=[],
5454
# Composition high ops such are ntt, mod, and relin are not currently supported
55-
choices=[
56-
"add",
57-
"sub",
58-
"mul",
59-
"muli",
60-
"copy",
61-
"ntt",
62-
], # currently supports single ops
55+
choices=["add", "sub", "mul", "muli", "copy", "ntt", "intt", "mod"],
6356
help="List of high_op names",
6457
)
6558
parser.add_argument(
@@ -111,11 +104,16 @@ def main(args):
111104
if args.target and any(
112105
target.lower() in str(kernel).lower() for target in args.target
113106
):
114-
kernel = loop_interchange(
115-
kernel.to_pisa(),
116-
primary_key=args.primary,
117-
secondary_key=args.secondary,
107+
reorderable, non_reorderable = split_by_reorderable(kernel.to_pisa())
108+
kernel = non_reorderable
109+
kernel.append(
110+
loop_interchange(
111+
reorderable,
112+
primary_key=args.primary,
113+
secondary_key=args.secondary,
114+
)
118115
)
116+
119117
for pisa in mixed_to_pisa_ops(kernel):
120118
print(pisa)
121119
else:

p-isa_tools/kerngen/pisa_generators/mod.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (C) 2024 Intel Corporation
1+
# Copyright (C) 2025 Intel Corporation
22
# SPDX-License-Identifier: Apache-2.0
33

44
"""Module containing conversions or operations from isa to p-isa."""
@@ -89,7 +89,7 @@ def generate_mod_stages() -> list[Stage]:
8989
stages.append(
9090
Stage(
9191
[
92-
Comment("Mod Stage 1"),
92+
Comment("Mod Stage 1 <reorderable>"),
9393
muli_last_half(
9494
self.context,
9595
temp_input_remaining_rns,
@@ -194,7 +194,7 @@ def generate_mod_stages() -> list[Stage]:
194194
+ stages[2].pisa_ops
195195
+ [
196196
Muli(self.context, self.output, temp_input_remaining_rns, iq),
197-
Comment("End of mod kernel"),
197+
Comment("End of mod kernel </reorderable>"),
198198
]
199199
)
200200

0 commit comments

Comments
 (0)