Skip to content

Commit 992fe29

Browse files
author
git apple-llvm automerger
committed
Merge commit '2e40c567fbf5' from llvm.org/main into next
2 parents b0fa49e + 2e40c56 commit 992fe29

File tree

2 files changed

+99
-10
lines changed

2 files changed

+99
-10
lines changed

mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp

Lines changed: 76 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -121,23 +121,89 @@ ModuleOp transform::detail::getPreloadedTransformModule(MLIRContext *context) {
121121
->getLibraryModule();
122122
}
123123

124+
static transform::TransformOpInterface
125+
findTransformEntryPointNonRecursive(Operation *op, StringRef entryPoint) {
126+
for (Region &region : op->getRegions()) {
127+
for (Block &block : region.getBlocks()) {
128+
for (auto namedSequenceOp : block.getOps<transform::NamedSequenceOp>()) {
129+
if (namedSequenceOp.getSymName() == entryPoint) {
130+
return cast<transform::TransformOpInterface>(
131+
namedSequenceOp.getOperation());
132+
}
133+
}
134+
}
135+
}
136+
return nullptr;
137+
}
138+
139+
static transform::TransformOpInterface
140+
findTransformEntryPointRecursive(Operation *op, StringRef entryPoint) {
141+
transform::TransformOpInterface transform = nullptr;
142+
op->walk<WalkOrder::PreOrder>(
143+
[&](transform::NamedSequenceOp namedSequenceOp) {
144+
if (namedSequenceOp.getSymName() == entryPoint) {
145+
transform = cast<transform::TransformOpInterface>(
146+
namedSequenceOp.getOperation());
147+
return WalkResult::interrupt();
148+
}
149+
return WalkResult::advance();
150+
});
151+
return transform;
152+
}
153+
154+
// Will look for the transform's entry point favouring NamedSequenceOps
155+
// ops that exist within the operation without the need for nesting.
156+
// If no operation exists in the blocks owned by op, then it will recursively
157+
// walk the op in preorder and find the first NamedSequenceOp that matches
158+
// the entry point's name.
159+
//
160+
// This allows for the following two use cases:
161+
// 1. op is a module annotated with the transform.with_named_sequence attribute
162+
// that has an entry point in its block. E.g.,
163+
//
164+
// ```mlir
165+
// module {transform.with_named_sequence} {
166+
// transform.named_sequence @__transform_main(%arg0 : !transform.any_op) ->
167+
// () {
168+
// transform.yield
169+
// }
170+
// }
171+
// ```
172+
//
173+
// 2. op is a program which contains a nested module annotated with the
174+
// transform.with_named_sequence attribute. E.g.,
175+
//
176+
// ```mlir
177+
// module {
178+
// func.func @foo () {
179+
// }
180+
//
181+
// module {transform.with_named_sequence} {
182+
// transform.named_sequence @__transform_main(%arg0 : !transform.any_op)
183+
// -> () {
184+
// transform.yield
185+
// }
186+
// }
187+
// }
188+
// ```
189+
static transform::TransformOpInterface
190+
findTransformEntryPointInOp(Operation *op, StringRef entryPoint) {
191+
transform::TransformOpInterface transform =
192+
findTransformEntryPointNonRecursive(op, entryPoint);
193+
if (!transform)
194+
transform = findTransformEntryPointRecursive(op, entryPoint);
195+
return transform;
196+
}
197+
124198
transform::TransformOpInterface
125199
transform::detail::findTransformEntryPoint(Operation *root, ModuleOp module,
126200
StringRef entryPoint) {
127201
SmallVector<Operation *, 2> l{root};
128202
if (module)
129203
l.push_back(module);
130204
for (Operation *op : l) {
131-
transform::TransformOpInterface transform = nullptr;
132-
op->walk<WalkOrder::PreOrder>(
133-
[&](transform::NamedSequenceOp namedSequenceOp) {
134-
if (namedSequenceOp.getSymName() == entryPoint) {
135-
transform = cast<transform::TransformOpInterface>(
136-
namedSequenceOp.getOperation());
137-
return WalkResult::interrupt();
138-
}
139-
return WalkResult::advance();
140-
});
205+
TransformOpInterface transform =
206+
findTransformEntryPointInOp(op, entryPoint);
141207
if (transform)
142208
return transform;
143209
}
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
// RUN: mlir-opt %s --transform-interpreter | FileCheck %s
2+
3+
module @td_module_4 attributes {transform.with_named_sequence} {
4+
module @foo_module attributes {transform.with_named_sequence} {
5+
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) -> () {
6+
// CHECK: IR printer: foo_module top-level
7+
transform.print {name="foo_module"}
8+
transform.yield
9+
}
10+
}
11+
module @bar_module attributes {transform.with_named_sequence} {
12+
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) -> () {
13+
// CHECK: IR printer: bar_module top-level
14+
transform.print {name="bar_module"}
15+
transform.yield
16+
}
17+
}
18+
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) -> () {
19+
transform.include @foo_module::@__transform_main failures(suppress) (%arg0) : (!transform.any_op) -> ()
20+
transform.include @bar_module::@__transform_main failures(suppress) (%arg0) : (!transform.any_op) -> ()
21+
transform.yield
22+
}
23+
}

0 commit comments

Comments
 (0)