@@ -117,7 +117,7 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
117117 constAttr);
118118 }
119119
120- if (isa<AtenAnyOp, AtenAnyDimOp>(op)) {
120+ if (isa<AtenAnyOp, AtenAnyDimOp, AtenAnyDimsOp >(op)) {
121121 auto constAttr =
122122 DenseElementsAttr::get (constType, {APInt (/* numBits=*/ 1 , 0 )});
123123 return rewriter.create <stablehlo::ConstantOp>(op->getLoc (), constType,
@@ -169,7 +169,7 @@ static Value createReduceOpWithSingleRegionOp(Operation *op, Value input,
169169 } else if (isa<AtenAllOp, AtenAllDimOp>(op)) {
170170 result = rewriter.create <stablehlo::AndOp>(
171171 op->getLoc (), blockArgumentTy, *firstArgument, *secondArgument);
172- } else if (isa<AtenAnyOp, AtenAnyDimOp>(op)) {
172+ } else if (isa<AtenAnyOp, AtenAnyDimOp, AtenAnyDimsOp >(op)) {
173173 result = rewriter.create <stablehlo::OrOp>(
174174 op->getLoc (), blockArgumentTy, *firstArgument, *secondArgument);
175175 } else if (isa<AtenProdOp, AtenProdDimIntOp>(op)) {
@@ -610,6 +610,82 @@ class ConvertAtenReduceWithIndicesOp : public ConvertAtenReductionOp<AtenOpT> {
610610};
611611} // namespace
612612
613+ // AtenAnyDimsOp
614+ namespace {
615+ template <>
616+ LogicalResult ConvertAtenReductionOp<AtenAnyDimsOp>::matchAndRewrite(
617+ AtenAnyDimsOp op, OpAdaptor adaptor,
618+ ConversionPatternRewriter &rewriter) const {
619+ Value input = adaptor.getSelf ();
620+ auto inputTy = dyn_cast<RankedTensorType>(input.getType ());
621+ auto outTy =
622+ dyn_cast<RankedTensorType>(getTypeConverter ()->convertType (op.getType ()));
623+ if (!inputTy) {
624+ return rewriter.notifyMatchFailure (
625+ op, " only Tensor types supported in StableHLO" );
626+ }
627+ if (inputTy.getElementType () != outTy.getElementType ()) {
628+ // Use output element type as computation type.
629+ auto dstElemTy = outTy.getElementType ();
630+ input =
631+ rewriter.create <stablehlo::ConvertOp>(op->getLoc (), input, dstElemTy);
632+ inputTy = dyn_cast<RankedTensorType>(input.getType ());
633+ }
634+ auto inputElemTy = inputTy.getElementType ();
635+ if (!inputElemTy.isIntOrFloat ()) {
636+ return op.emitError (
637+ " Only floating-point or integer datatype legalization supported" );
638+ }
639+
640+ SmallVector<int64_t > inputDims;
641+ SmallVector<int64_t > dims;
642+ if (!matchPattern (op.getDim (), m_TorchListOfConstantInts (inputDims))) {
643+ return rewriter.notifyMatchFailure (
644+ op, " non-const integer `dim` is not supported" );
645+ }
646+ if (inputDims.size () == 0 ) {
647+ rewriter.replaceOp (op, input);
648+ return success ();
649+ }
650+ for (auto d : inputDims) {
651+ d = toPositiveDim (d, inputTy.getRank ());
652+ // Drop invalid dims
653+ if (isValidDim (d, inputTy.getRank ())) {
654+ dims.push_back (d);
655+ }
656+ }
657+ llvm::sort (dims.begin (), dims.end ());
658+
659+ SmallVector<int64_t > reduceResultShape =
660+ getReduceOutputShape (inputTy.getShape (), dims);
661+
662+ bool keepDim = false ;
663+ if (!matchPattern (op.getKeepdim (), m_TorchConstantBool (&keepDim))) {
664+ return rewriter.notifyMatchFailure (op, " non-bool keepdim unsupported" );
665+ }
666+
667+ Value reduceResult = createReduceOpWithSingleRegionOp (
668+ op, input,
669+ RankedTensorType::get (reduceResultShape, outTy.getElementType ()), dims,
670+ rewriter);
671+ if (!reduceResult) {
672+ return op->emitError (" createReduceOpWithSingleRegionOp return nullptr" );
673+ }
674+
675+ if (keepDim) {
676+ auto outShapeInfo = hlo::getDimIndexOfTensor (rewriter, op, input);
677+ if (failed (outShapeInfo)) {
678+ return rewriter.notifyMatchFailure (
679+ op, " failed to get dimension sizes of the input" );
680+ }
681+ reduceResult = reshapeReduceResultWhenKeepDim (
682+ rewriter, op->getLoc (), reduceResult, *outShapeInfo, outTy, dims);
683+ }
684+ rewriter.replaceOp (op, reduceResult);
685+ return success ();
686+ }
687+ } // namespace
688+
613689// AtenSumDimIntListOp
614690namespace {
615691template <>
@@ -928,6 +1004,7 @@ void mlir::torch::torch_to_stablehlo::populateReductionOpPatternsAndLegality(
9281004#define INSERT_ATEN_REDUCTION_OP_PATTERN (AtenOp ) \
9291005 target.addIllegalOp <AtenOp>(); \
9301006 patterns.add <ConvertAtenReductionOp<AtenOp>>(typeConverter, context, options)
1007+ INSERT_ATEN_REDUCTION_OP_PATTERN (AtenAnyDimsOp);
9311008 INSERT_ATEN_REDUCTION_OP_PATTERN (AtenSumDimIntListOp);
9321009 INSERT_ATEN_REDUCTION_OP_PATTERN (AtenFrobeniusNormDimOp);
9331010 INSERT_ATEN_REDUCTION_OP_PATTERN (AtenLinalgVectorNormOp);
0 commit comments