@@ -10553,6 +10553,82 @@ class DecomposeAtenNllLossForwardOp
1055310553};
1055410554} // namespace
1055510555
10556+ namespace {
10557+ class DecomposeAtenPoissonNllLossOp
10558+ : public OpRewritePattern<AtenPoissonNllLossOp> {
10559+ public:
10560+ using OpRewritePattern::OpRewritePattern;
10561+ LogicalResult matchAndRewrite(AtenPoissonNllLossOp op,
10562+ PatternRewriter &rewriter) const override {
10563+ Location loc = op.getLoc();
10564+ Value input = op.getInput();
10565+ Value target = op.getTarget();
10566+ Value logInput = op.getLogInput();
10567+ Value full = op.getFull();
10568+ Value reduction = op.getReduction();
10569+ Value eps = op.getEps();
10570+
10571+ bool logInVal, fullVal;
10572+ if (!matchPattern(logInput, m_TorchConstantBool(&logInVal)))
10573+ return rewriter.notifyMatchFailure(
10574+ op, "expected logInput argument to be constant bool");
10575+ if (!matchPattern(full, m_TorchConstantBool(&fullVal)))
10576+ return rewriter.notifyMatchFailure(
10577+ op, "expected full argument to be constant bool");
10578+
10579+ int64_t reductionInt;
10580+ if (!matchPattern(reduction, m_TorchConstantInt(&reductionInt)))
10581+ return rewriter.notifyMatchFailure(op, "expected constant reduction");
10582+
10583+ double epsFloat;
10584+ if (!matchPattern(eps, m_TorchConstantFloat(&epsFloat))) {
10585+ return rewriter.notifyMatchFailure(op, "expected constant eps");
10586+ }
10587+ // TODO: add support for full=true (Stirling approximation)
10588+ if (fullVal)
10589+ return rewriter.notifyMatchFailure(
10590+ op, "Unimplemented: full loss computation is not supported");
10591+
10592+ Value one =
10593+ rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
10594+ Value epsConst = rewriter.create<ConstantFloatOp>(
10595+ loc, rewriter.getF64FloatAttr(epsFloat));
10596+
10597+ Value safeInput = rewriter.create<AtenAddScalarOp>(loc, input.getType(),
10598+ input, epsConst, one);
10599+
10600+ Value loss;
10601+ if (logInVal) {
10602+ Value expIn = rewriter.create<AtenExpOp>(loc, input.getType(), input);
10603+ Value targetMulInput =
10604+ rewriter.create<AtenMulTensorOp>(loc, input.getType(), target, input);
10605+ loss = rewriter.create<AtenSubTensorOp>(loc, input.getType(), expIn,
10606+ targetMulInput, one);
10607+ } else {
10608+ Value logSafeInput =
10609+ rewriter.create<AtenLogOp>(loc, input.getType(), safeInput);
10610+ Value targetMulLog = rewriter.create<AtenMulTensorOp>(
10611+ loc, input.getType(), target, logSafeInput);
10612+ loss = rewriter.create<AtenSubTensorOp>(loc, input.getType(), input,
10613+ targetMulLog, one);
10614+ }
10615+
10616+ Value result = loss;
10617+ if (reductionInt == 1) {
10618+ // Case 1: Mean Reduction
10619+ result = rewriter.create<AtenMeanOp>(
10620+ loc, op.getType(), loss, rewriter.create<ConstantNoneOp>(loc));
10621+ } else if (reductionInt == 2) {
10622+ // Case 2: Sum Reduction
10623+ result = rewriter.create<AtenSumOp>(loc, op.getType(), loss,
10624+ rewriter.create<ConstantNoneOp>(loc));
10625+ }
10626+ rewriter.replaceOp(op, result);
10627+ return success();
10628+ }
10629+ };
10630+ } // namespace
10631+
1055610632namespace {
1055710633class DecomposeAtenBinaryCrossEntropyWithLogitsOp
1055810634 : public OpRewritePattern<AtenBinaryCrossEntropyWithLogitsOp> {
@@ -12467,6 +12543,7 @@ class DecomposeComplexOpsPass
1246712543 addPatternIfTargetOpIsIllegal<DecomposeAtenOneHotOp>(patterns);
1246812544 addPatternIfTargetOpIsIllegal<DecomposeAtenCrossEntropyLossOp>(patterns);
1246912545 addPatternIfTargetOpIsIllegal<DecomposeAtenNllLossForwardOp>(patterns);
12546+ addPatternIfTargetOpIsIllegal<DecomposeAtenPoissonNllLossOp>(patterns);
1247012547 addPatternIfTargetOpIsIllegal<DecomposeAtenBinaryCrossEntropyWithLogitsOp>(
1247112548 patterns);
1247212549 addPatternIfTargetOpIsIllegal<DecomposeAtenVarMeanDimOp>(patterns);
0 commit comments