@@ -324,8 +324,10 @@ setInPlaceFuncArgument(BlockArgument bbArg,
324324
325325// / Remove the attribute that triggers inplace bufferization on a FuncOp
326326// / argument `bbArg`.
327- static void removeInPlaceFuncArgument (BlockArgument bbArg) {
327+ static void removeBufferizationFuncArguments (BlockArgument bbArg) {
328328 auto funcOp = cast<FuncOp>(bbArg.getOwner ()->getParentOp ());
329+ funcOp.removeArgAttr (bbArg.getArgNumber (),
330+ LinalgDialect::kBufferLayoutAttrName );
329331 funcOp.removeArgAttr (bbArg.getArgNumber (),
330332 LinalgDialect::kInplaceableAttrName );
331333}
@@ -2608,6 +2610,96 @@ static void applyEnablingTransformations(ModuleOp moduleOp) {
26082610 (void )applyPatternsAndFoldGreedily (moduleOp, std::move (patterns));
26092611}
26102612
2613+ static void
2614+ foreachCaller (const DenseMap<FuncOp, DenseSet<Operation *>> &callerMap,
2615+ FuncOp callee, llvm::function_ref<void (Operation *)> doit) {
2616+ auto itCallers = callerMap.find (callee);
2617+ if (itCallers == callerMap.end ())
2618+ return ;
2619+ for (Operation *caller : itCallers->second )
2620+ doit (caller);
2621+ }
2622+
2623+ // / Postprocess the linalg.buffer_layout annotation across function boundaries.
2624+ // / This is a purely mechanical process that may later become part of a
2625+ // / separate pass with its own layout assignment heuristic.
2626+ static void layoutPostProcessing (ModuleOp moduleOp) {
2627+ SmallVector<FuncOp> orderedFuncOps;
2628+ DenseMap<FuncOp, DenseSet<Operation *>> callerMap;
2629+ auto res = getFuncOpsOrderedByCalls (moduleOp, orderedFuncOps, callerMap);
2630+ assert (succeeded (res) && " unexpected getFuncOpsOrderedByCalls failure" );
2631+
2632+ for (FuncOp funcOp : orderedFuncOps) {
2633+ DenseMap<Operation *, SmallVector<Value>> operandsPerCaller;
2634+ foreachCaller (callerMap, funcOp, [&](Operation *caller) {
2635+ operandsPerCaller.try_emplace (caller, SmallVector<Value>());
2636+ });
2637+
2638+ SmallVector<Type> argumentTypes;
2639+ // Iterate on each function argument and check it it was marked with a
2640+ // desired layout.
2641+ for (auto it : llvm::enumerate (funcOp.getType ().getInputs ())) {
2642+ int argNumber = it.index ();
2643+ Type inputType = it.value ();
2644+ auto memrefType = inputType.dyn_cast <MemRefType>();
2645+ auto layoutAttr = funcOp.getArgAttrOfType <AffineMapAttr>(
2646+ argNumber, LinalgDialect::kBufferLayoutAttrName );
2647+ AffineMap desiredLayoutMap =
2648+ layoutAttr ? layoutAttr.getValue () : AffineMap ();
2649+ AffineMap currentLayoutMap =
2650+ memrefType ? getStridedLinearLayoutMap (memrefType) : AffineMap ();
2651+ if (!memrefType || !layoutAttr || desiredLayoutMap == currentLayoutMap) {
2652+ argumentTypes.push_back (inputType);
2653+ foreachCaller (callerMap, funcOp, [&](Operation *caller) {
2654+ operandsPerCaller.find (caller)->getSecond ().push_back (
2655+ caller->getOperand (argNumber));
2656+ });
2657+ continue ;
2658+ }
2659+
2660+ // Compute the buffer type with desired layout and add to input argument
2661+ // types.
2662+ MemRefType desiredMemrefType = MemRefType::get (
2663+ memrefType.getShape (), memrefType.getElementType (), desiredLayoutMap);
2664+ argumentTypes.push_back (desiredMemrefType);
2665+
2666+ // If funcOp's body is not empty, change the bbArg type and propagate.
2667+ if (!funcOp.body ().empty ()) {
2668+ BlockArgument bbArg = funcOp.getArgument (argNumber);
2669+ bbArg.setType (desiredMemrefType);
2670+ OpBuilder b (bbArg.getContext ());
2671+ b.setInsertionPointToStart (bbArg.getOwner ());
2672+ // Cast back to the original memrefType and let it canonicalize.
2673+ Value cast =
2674+ b.create <memref::CastOp>(funcOp.getLoc (), memrefType, bbArg);
2675+ bbArg.replaceAllUsesExcept (cast, cast.getDefiningOp ());
2676+ }
2677+
2678+ // Cast to desired buffer type on all callers to `funcOp`.
2679+ // TODO: on the callee side, this may even have to trigger a copy to
2680+ // change the layout. For now let the memref::CastOp fail to verify in
2681+ // such cases.
2682+ auto castArg = [&](Operation *caller) {
2683+ OpBuilder b (caller);
2684+ Value newOperand = b.create <memref::CastOp>(
2685+ funcOp.getLoc (), desiredMemrefType, caller->getOperand (argNumber));
2686+ operandsPerCaller.find (caller)->getSecond ().push_back (newOperand);
2687+ };
2688+ foreachCaller (callerMap, funcOp, castArg);
2689+ }
2690+
2691+ // Set operands with cast buffer on all callers to `funcOp`.
2692+ foreachCaller (callerMap, funcOp, [&](Operation *caller) {
2693+ caller->setOperands (operandsPerCaller.lookup (caller));
2694+ });
2695+
2696+ // Finally set the funcOp type to update the arguments.
2697+ auto newFuncType = FunctionType::get (moduleOp.getContext (), argumentTypes,
2698+ funcOp.getType ().getResults ());
2699+ funcOp.setType (newFuncType);
2700+ }
2701+ }
2702+
26112703void LinalgComprehensiveModuleBufferize::runOnOperation () {
26122704 ModuleOp moduleOp = getOperation ();
26132705 applyEnablingTransformations (moduleOp);
@@ -2672,12 +2764,16 @@ void LinalgComprehensiveModuleBufferize::runOnOperation() {
26722764 }
26732765 }
26742766
2675- // Post-pass cleanup of inplaceable attributes.
2767+ // Perform a post-processing pass of layout modification at function boundary
2768+ // according to the kBufferLayoutAttrName.
2769+ layoutPostProcessing (moduleOp);
2770+
2771+ // Post-pass cleanup of inplaceable and buffer_layout attributes.
26762772 moduleOp.walk (
26772773 [&](Operation *op) { op->removeAttr (kInPlaceResultsAttrName ); });
26782774 moduleOp.walk ([&](FuncOp op) {
26792775 for (BlockArgument bbArg : op.getArguments ())
2680- removeInPlaceFuncArgument (bbArg);
2776+ removeBufferizationFuncArguments (bbArg);
26812777 });
26822778
26832779 OpPassManager cleanupPipeline (OpPassManager (" module" ));
0 commit comments