@@ -11304,6 +11304,76 @@ class DecomposeAtenSgnOp : public OpRewritePattern<AtenSgnOp> {
1130411304};
1130511305} // namespace
1130611306
11307+ namespace {
11308+ // Decomposed aten.heaviside op into
11309+ // using aten.eq, aten.lt, aten.logical_or, aten.where
11310+ // Heaviside(x, y) returns
11311+ // 0 if x < 0
11312+ // y if x == 0
11313+ // 1 if x > 0
11314+ class DecomposeAtenHeaviside : public OpRewritePattern<AtenHeavisideOp> {
11315+ public:
11316+ using OpRewritePattern::OpRewritePattern;
11317+ LogicalResult matchAndRewrite(AtenHeavisideOp op,
11318+ PatternRewriter &rewriter) const override {
11319+ auto input = op.getSelf();
11320+ auto value = op.getValues();
11321+ auto loc = op.getLoc();
11322+ auto inputTy = dyn_cast<BaseTensorType>(input.getType());
11323+ if (!inputTy || !inputTy.hasDtype() || !inputTy.hasSizes())
11324+ return rewriter.notifyMatchFailure(op, "input must have dtype and size.");
11325+
11326+ auto valueTy = dyn_cast<BaseTensorType>(value.getType());
11327+ if (!valueTy || !valueTy.hasDtype() || !valueTy.hasSizes())
11328+ return rewriter.notifyMatchFailure(op, "value must have dtype and size.");
11329+ auto resultTy = dyn_cast<BaseTensorType>(op.getType());
11330+ SmallVector<int64_t> broadcastShape;
11331+ SmallVector<Value> broadcastShapeValue;
11332+ computeBroadcastShape(rewriter, loc, input, value, broadcastShape,
11333+ broadcastShapeValue);
11334+
11335+ auto broadcastType = ValueTensorType::get(
11336+ op.getContext(), llvm::ArrayRef(broadcastShape), resultTy.getDtype());
11337+ auto boolBroadcastType = ValueTensorType::get(
11338+ op.getContext(), llvm::ArrayRef(broadcastShape), rewriter.getI1Type());
11339+ Value indexBroadcastShapeTorchList = rewriter.create<PrimListConstructOp>(
11340+ loc, Torch::ListType::get(Torch::IntType::get(op.getContext())),
11341+ broadcastShapeValue);
11342+ auto inputBroadcasted = rewriter.create<AtenBroadcastToOp>(
11343+ loc, broadcastType, input, indexBroadcastShapeTorchList);
11344+ auto valueBroadcasted = rewriter.create<AtenBroadcastToOp>(
11345+ loc, broadcastType, value, indexBroadcastShapeTorchList);
11346+
11347+ Value zero = getConstantWithGivenDtypeAndValue(rewriter, loc, 0,
11348+ resultTy.getDtype());
11349+ Value one = getConstantWithGivenDtypeAndValue(rewriter, loc, 1,
11350+ resultTy.getDtype());
11351+ // Compute mask: input == 0
11352+ auto inputEqZero = rewriter
11353+ .create<AtenEqScalarOp>(loc, boolBroadcastType,
11354+ inputBroadcasted, zero)
11355+ ->getResult(0);
11356+ // Compute mask: input < 0
11357+ auto inputLtZero = rewriter.create<AtenLtScalarOp>(loc, boolBroadcastType,
11358+ inputBroadcasted, zero);
11359+ // Compute mask: isnan(input)
11360+ auto isNan =
11361+ rewriter.create<AtenIsnanOp>(loc, boolBroadcastType, inputBroadcasted);
11362+ // Combine: input < 0 || isnan(input)
11363+ auto inputNegativeOrNan = rewriter.create<AtenLogicalOrOp>(
11364+ loc, boolBroadcastType, inputLtZero, isNan);
11365+ // Select 0 if input < 0 or input is nan, else 1
11366+ auto zerosOrOnes = rewriter.create<AtenWhereScalarOp>(
11367+ loc, resultTy, inputNegativeOrNan, zero, one);
11368+ // Final result: if input == 0, take from valueBroadcasted, else take from
11369+ // zerosOrOnes
11370+ rewriter.replaceOpWithNewOp<AtenWhereSelfOp>(op, resultTy, inputEqZero,
11371+ valueBroadcasted, zerosOrOnes);
11372+ return success();
11373+ }
11374+ };
11375+ } // namespace
11376+
1130711377namespace {
1130811378// Unconditionally decompose `torch.type_as` into `prim.dtype` +
1130911379// `torch.to.dtype`.
@@ -12528,6 +12598,7 @@ class DecomposeComplexOpsPass
1252812598 DecomposeConstantTensorNewLikeOp<AtenNewOnesOp, AtenOnesOp>>(patterns);
1252912599 addPatternIfTargetOpIsIllegal<DecomposeAtenHardtanhOp>(patterns);
1253012600 addPatternIfTargetOpIsIllegal<DecomposeAtenFullOp>(patterns);
12601+ addPatternIfTargetOpIsIllegal<DecomposeAtenHeaviside>(patterns);
1253112602 addPatternIfTargetOpIsIllegal<DecomposeAtenLinearOp>(patterns);
1253212603 addPatternIfTargetOpIsIllegal<DecomposeAtenMishOp>(patterns);
1253312604 addPatternIfTargetOpIsIllegal<DecomposeAtenFullLikeOp>(patterns);
0 commit comments