@@ -121,23 +121,89 @@ ModuleOp transform::detail::getPreloadedTransformModule(MLIRContext *context) {
121121 ->getLibraryModule ();
122122}
123123
124+ static transform::TransformOpInterface
125+ findTransformEntryPointNonRecursive (Operation *op, StringRef entryPoint) {
126+ for (Region ®ion : op->getRegions ()) {
127+ for (Block &block : region.getBlocks ()) {
128+ for (auto namedSequenceOp : block.getOps <transform::NamedSequenceOp>()) {
129+ if (namedSequenceOp.getSymName () == entryPoint) {
130+ return cast<transform::TransformOpInterface>(
131+ namedSequenceOp.getOperation ());
132+ }
133+ }
134+ }
135+ }
136+ return nullptr ;
137+ }
138+
139+ static transform::TransformOpInterface
140+ findTransformEntryPointRecursive (Operation *op, StringRef entryPoint) {
141+ transform::TransformOpInterface transform = nullptr ;
142+ op->walk <WalkOrder::PreOrder>(
143+ [&](transform::NamedSequenceOp namedSequenceOp) {
144+ if (namedSequenceOp.getSymName () == entryPoint) {
145+ transform = cast<transform::TransformOpInterface>(
146+ namedSequenceOp.getOperation ());
147+ return WalkResult::interrupt ();
148+ }
149+ return WalkResult::advance ();
150+ });
151+ return transform;
152+ }
153+
154+ // Will look for the transform's entry point favouring NamedSequenceOps
155+ // ops that exist within the operation without the need for nesting.
156+ // If no operation exists in the blocks owned by op, then it will recursively
157+ // walk the op in preorder and find the first NamedSequenceOp that matches
158+ // the entry point's name.
159+ //
160+ // This allows for the following two use cases:
161+ // 1. op is a module annotated with the transform.with_named_sequence attribute
162+ // that has an entry point in its block. E.g.,
163+ //
164+ // ```mlir
165+ // module {transform.with_named_sequence} {
166+ // transform.named_sequence @__transform_main(%arg0 : !transform.any_op) ->
167+ // () {
168+ // transform.yield
169+ // }
170+ // }
171+ // ```
172+ //
173+ // 2. op is a program which contains a nested module annotated with the
174+ // transform.with_named_sequence attribute. E.g.,
175+ //
176+ // ```mlir
177+ // module {
178+ // func.func @foo () {
179+ // }
180+ //
181+ // module {transform.with_named_sequence} {
182+ // transform.named_sequence @__transform_main(%arg0 : !transform.any_op)
183+ // -> () {
184+ // transform.yield
185+ // }
186+ // }
187+ // }
188+ // ```
189+ static transform::TransformOpInterface
190+ findTransformEntryPointInOp (Operation *op, StringRef entryPoint) {
191+ transform::TransformOpInterface transform =
192+ findTransformEntryPointNonRecursive (op, entryPoint);
193+ if (!transform)
194+ transform = findTransformEntryPointRecursive (op, entryPoint);
195+ return transform;
196+ }
197+
124198transform::TransformOpInterface
125199transform::detail::findTransformEntryPoint (Operation *root, ModuleOp module ,
126200 StringRef entryPoint) {
127201 SmallVector<Operation *, 2 > l{root};
128202 if (module )
129203 l.push_back (module );
130204 for (Operation *op : l) {
131- transform::TransformOpInterface transform = nullptr ;
132- op->walk <WalkOrder::PreOrder>(
133- [&](transform::NamedSequenceOp namedSequenceOp) {
134- if (namedSequenceOp.getSymName () == entryPoint) {
135- transform = cast<transform::TransformOpInterface>(
136- namedSequenceOp.getOperation ());
137- return WalkResult::interrupt ();
138- }
139- return WalkResult::advance ();
140- });
205+ TransformOpInterface transform =
206+ findTransformEntryPointInOp (op, entryPoint);
141207 if (transform)
142208 return transform;
143209 }
0 commit comments