@@ -158,29 +158,41 @@ LinalgTilingOptions &mlir::linalg::LinalgTilingOptions::scalarizeDynamicDims() {
158158 return *this ;
159159}
160160
161- // / Pad `opOperand` using the provided `paddingValues`. Exit early for scalar
162- // / operands, if `paddingValues` contains no value for the `opOperand `, or if
163- // / `opOperand` is not defined by an ExtractSliceOp. Otherwise, try to pad the
164- // / operand even if it already has a static shape. Set `result` to the result of
165- // / the created tensor::PadOp or and return success if the operand either has
166- // / been padded to a static shape or already had a static shape and failure
167- // / otherwise.
168- static LogicalResult padOperandToSmallestStaticBoundingBox (
161+ // / Pad the `opOperand` in the `paddingDimensions` using the padding value and
162+ // / the nofold flag found in `paddingValues` and `packPaddings `, respectively.
163+ // / Exit early and return the `opOperand` value if the shape dimensions that
164+ // / match `paddingDimensions` have a static size and the nofold flag is not set.
165+ // / Otherwise, try to pad the shape dimensions that match the iterator
166+ // / dimensions `paddingDimensions` and return the tensor::PadOp result if
167+ // / padding succeeds or failure otherwise.
168+ static FailureOr<Value> padOperandToSmallestStaticBoundingBox (
169169 OpBuilder &b, linalg::LinalgOp opToPad, OpOperand *opOperand,
170- ArrayRef<Attribute> paddingValues, ArrayRef<bool > packPaddings,
171- Value &result) {
172- // Get the shape of the operand and check if it has a dynamic shape. Only
173- // return failure if the operand is not a scalar and has a dynamic shape.
170+ ArrayRef<int64_t > paddingDimensions, ArrayRef<Attribute> paddingValues,
171+ ArrayRef<bool > packPaddings) {
172+ AffineMap indexingMap = opToPad.getTiedIndexingMap (opOperand);
174173 ArrayRef<int64_t > shape = opToPad.getShape (opOperand);
175- bool hasDynamicShape = llvm::is_contained (shape, ShapedType::kDynamicSize );
176174
177- // Cannot pad scalar operands.
178- if (shape.empty ())
179- return success ();
175+ // Collect the shape dimension that are a function of the `paddingDimensions`.
176+ llvm::SmallDenseSet<int64_t > shapeDimsToPad;
177+ for (int64_t dim : paddingDimensions)
178+ for (const auto &en : enumerate(indexingMap.getResults ()))
179+ if (en.value ().isFunctionOfDim (dim))
180+ shapeDimsToPad.insert (en.index ());
180181
181- // Cannot pad if the padding value is unknown.
182+ // Return the unpadded operand if padding to a static shape is not needed and
183+ // if the nofold flag is not set.
184+ bool nofold = opOperand->getOperandNumber () < packPaddings.size ()
185+ ? packPaddings[opOperand->getOperandNumber ()]
186+ : false ;
187+ bool hasStaticShape = llvm::none_of (shapeDimsToPad, [&](int64_t dim) {
188+ return ShapedType::isDynamic (shape[dim]);
189+ });
190+ if (!nofold && hasStaticShape)
191+ return opOperand->get ();
192+
193+ // Fail if `paddingValues` specifies no padding value.
182194 if (opOperand->getOperandNumber () >= paddingValues.size ())
183- return failure (hasDynamicShape );
195+ return failure ();
184196 Attribute paddingAttr = paddingValues[opOperand->getOperandNumber ()];
185197 Value paddingValue = b.create <arith::ConstantOp>(
186198 opToPad.getLoc (), paddingAttr.getType (), paddingAttr);
@@ -192,27 +204,31 @@ static LogicalResult padOperandToSmallestStaticBoundingBox(
192204 currOpOperand = linalgOp.getOutputOperand (result.getResultNumber ());
193205 }
194206
195- // Cannot construct a static bounding box if the `currOpOperand` is not
196- // defined by an ExtractSliceOp.
207+ // Fail if `currOpOperand` is not defined by an ExtractSliceOp.
197208 auto sliceOp = currOpOperand->get ().getDefiningOp <tensor::ExtractSliceOp>();
198209 if (!sliceOp)
199- return failure (hasDynamicShape );
210+ return failure ();
200211
201212 // Compute the dropped dimensions if `sliceOp` is ranke-reducing.
202213 llvm::SmallBitVector droppedDims = sliceOp.getDroppedDims ();
214+ OffsetSizeAndStrideOpInterface shapedOp = sliceOp;
203215
204216 // Upper bound the `sliceOp` sizes to obtain a static bounding box.
205- SmallVector<int64_t > staticSizes;
206- staticSizes.reserve (shape.size ());
207- auto shapedOp = cast<OffsetSizeAndStrideOpInterface>(sliceOp.getOperation ());
217+ SmallVector<int64_t > paddedShape (shape.begin (), shape.end ());
218+ int64_t shapeIdx = 0 ;
208219 for (const auto &en : enumerate(shapedOp.getMixedSizes ())) {
209220 // Skip dropped dimensions.
210221 if (droppedDims.test (en.index ()))
211222 continue ;
212- // If the size is an attribute add it directly to `staticSizes`.
223+ // Skip dimensions that do not require padding.
224+ if (!shapeDimsToPad.contains (shapeIdx)) {
225+ shapeIdx++;
226+ continue ;
227+ }
228+ // If the size is an attribute add it directly to `paddedShape`.
213229 if (en.value ().is <Attribute>()) {
214- staticSizes. push_back (
215- en.value ().get <Attribute>().dyn_cast <IntegerAttr>().getInt ()) ;
230+ paddedShape[shapeIdx++] =
231+ en.value ().get <Attribute>().dyn_cast <IntegerAttr>().getInt ();
216232 continue ;
217233 }
218234 // Otherwise, try to compute a constant upper bound for the size value.
@@ -222,24 +238,21 @@ static LogicalResult padOperandToSmallestStaticBoundingBox(
222238 LLVM_DEBUG (DBGS () << " No constant bounding box can be found for padding" );
223239 return failure ();
224240 }
225- staticSizes. push_back ( upperBound.getValue () );
241+ paddedShape[shapeIdx++] = upperBound.getValue ();
226242 }
227- assert (staticSizes. size () == shape.size () &&
243+ assert (shapeIdx == static_cast < int64_t >( shape.size () ) &&
228244 " expect the dynamic and static ranks to match" );
229245
230- // Pad the operand to the bounding box defined by `staticSizes`.
231- auto staticTensorType = RankedTensorType::get (
232- staticSizes, getElementTypeOrSelf (opOperand->get ()));
233- bool nofold = opOperand->getOperandNumber () < packPaddings.size ()
234- ? packPaddings[opOperand->getOperandNumber ()]
235- : false ;
236- result = makeComposedPadHighOp (b, opToPad->getLoc (), staticTensorType,
237- opOperand->get (), paddingValue, nofold);
238- return success ();
246+ // Pad the operand to the bounding box defined by `paddedShape`.
247+ auto paddedTensorType = RankedTensorType::get (
248+ paddedShape, getElementTypeOrSelf (opOperand->get ()));
249+ return makeComposedPadHighOp (b, opToPad->getLoc (), paddedTensorType,
250+ opOperand->get (), paddingValue, nofold);
239251}
240252
241253FailureOr<SmallVector<Value>>
242254linalg::rewriteAsPaddedOp (OpBuilder &b, LinalgOp opToPad,
255+ ArrayRef<int64_t > paddingDimensions,
243256 ArrayRef<Attribute> paddingValues,
244257 ArrayRef<bool > packPaddings, LinalgOp &paddedOp) {
245258 Location loc = opToPad->getLoc ();
@@ -255,13 +268,12 @@ linalg::rewriteAsPaddedOp(OpBuilder &b, LinalgOp opToPad,
255268 SmallVector<Value> newOperands;
256269 newOperands.reserve (opToPad.getNumInputsAndOutputs ());
257270 for (OpOperand *opOperand : opToPad.getInputAndOutputOperands ()) {
258- Value paddedOperand;
259- // If padding was requested but the shape cannot be bounded statically then
260- // the pattern fails to apply.
261- if (failed (padOperandToSmallestStaticBoundingBox (
262- b, opToPad, opOperand, paddingValues, packPaddings, paddedOperand)))
271+ FailureOr<Value> paddedOperand = padOperandToSmallestStaticBoundingBox (
272+ b, opToPad, opOperand, paddingDimensions, paddingValues, packPaddings);
273+ // Exit if `paddingDimensions` cannot be bounded statically.
274+ if (failed (paddedOperand))
263275 return failure ();
264- newOperands.push_back (paddedOperand ? paddedOperand : opOperand-> get () );
276+ newOperands.push_back (* paddedOperand);
265277 }
266278
267279 SmallVector<SmallVector<Value>> reifiedResultShapes;
@@ -502,19 +514,25 @@ mlir::linalg::LinalgPaddingPattern::returningMatchAndRewrite(
502514 // Pad the operation.
503515 LinalgOp paddedOp;
504516 FailureOr<SmallVector<Value>> newResults =
505- rewriteAsPaddedOp (rewriter, linalgOp, options.paddingValues ,
506- options.packPaddings , paddedOp);
517+ rewriteAsPaddedOp (rewriter, linalgOp, options.paddingDimensions ,
518+ options.paddingValues , options. packPaddings , paddedOp);
507519 if (failed (newResults))
508520 return failure ();
509521
510522 // Hoist the padding.
511523 for (const auto &en : enumerate(options.hoistPaddings )) {
512524 if (static_cast <int64_t >(en.index ()) >= paddedOp.getNumInputsAndOutputs ())
513525 break ;
514- OpOperand & opOperand = paddedOp->getOpOperand (en.index ());
515- auto padOp = opOperand. get ().getDefiningOp <tensor::PadOp>();
526+ OpOperand * opOperand = & paddedOp->getOpOperand (en.index ());
527+ auto padOp = opOperand-> get ().getDefiningOp <tensor::PadOp>();
516528 if (!padOp || en.value () == 0 )
517529 continue ;
530+
531+ // Fail hoisting if the operand shape is not fully static.
532+ if (llvm::any_of (paddedOp.getShape (opOperand),
533+ [](int64_t size) { return ShapedType::isDynamic (size); }))
534+ return failure ();
535+
518536 tensor::PadOp hoistedOp;
519537 SmallVector<GenericOp> transposeOps;
520538 SmallVector<int64_t > transposeVector =
0 commit comments