Skip to content

Commit 1da0692

Browse files
committed
fix TorchOps
Signed-off-by: Yu-Zhewen <zhewenyu@amd.com>
1 parent 5c9ce38 commit 1da0692

File tree

1 file changed

+15
-14
lines changed

1 file changed

+15
-14
lines changed

lib/Dialect/Torch/IR/TorchOps.cpp

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -372,21 +372,21 @@ LogicalResult ClassTypeOp::verify() {
372372
// PrimLoopOp
373373
//===----------------------------------------------------------------------===//
374374

375-
OperandRange PrimLoopOp::getEntrySuccessorOperands(RegionBranchPoint point) {
376-
assert(point == getRegion());
375+
OperandRange PrimLoopOp::getEntrySuccessorOperands(RegionSuccessor successor) {
376+
assert(successor.getSuccessor() == &getRegion());
377377
return getIterArgsInit();
378378
}
379379

380380
void PrimLoopOp::getSuccessorRegions(
381381
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
382382
Region &region = getRegion();
383-
if (!point.getRegionOrNull()) {
383+
if (!point.getTerminatorPredecessorOrNull()) {
384384
regions.emplace_back(&region, region.getArguments().slice(1));
385385
return;
386386
}
387-
assert(point == region);
387+
assert(point.getTerminatorPredecessorOrNull()->getParentRegion() == &region);
388388
regions.emplace_back(&region, region.getArguments().slice(1));
389-
regions.emplace_back(getResults());
389+
regions.emplace_back(getOperation(), getResults());
390390
}
391391

392392
bool PrimLoopOp::isForLike() {
@@ -399,7 +399,7 @@ bool PrimLoopOp::isForLike() {
399399
//===----------------------------------------------------------------------===//
400400

401401
MutableOperandRange
402-
PrimLoopConditionOp::getMutableSuccessorOperands(RegionBranchPoint point) {
402+
PrimLoopConditionOp::getMutableSuccessorOperands(RegionSuccessor successor) {
403403
// Pass all operands except the condition to the successor which is the
404404
// parent loop op.
405405
return getIterArgsMutable();
@@ -451,8 +451,8 @@ void PrimIfOp::print(OpAsmPrinter &p) {
451451
void PrimIfOp::getSuccessorRegions(RegionBranchPoint point,
452452
SmallVectorImpl<RegionSuccessor> &regions) {
453453
// The `then` and the `else` region branch back to the parent operation.
454-
if (point.getRegionOrNull()) {
455-
regions.push_back(RegionSuccessor(getResults()));
454+
if (point.getTerminatorPredecessorOrNull()) {
455+
regions.push_back(RegionSuccessor(getOperation(), getResults()));
456456
return;
457457
}
458458

@@ -5245,17 +5245,18 @@ template <typename CalculateOp>
52455245
static void
52465246
getSuccessorRegionsForCalculateOp(CalculateOp op, RegionBranchPoint point,
52475247
SmallVectorImpl<RegionSuccessor> &regions) {
5248-
if (!point.getRegionOrNull()) {
5248+
if (!point.getTerminatorPredecessorOrNull()) {
52495249
// First thing the op does is branch into the calculation.
52505250
regions.emplace_back(&op.getCalculation());
52515251
return;
52525252
}
5253-
if (point == op.getBody()) {
5253+
Region *region = point.getTerminatorPredecessorOrNull()->getParentRegion();
5254+
if (region == &op.getBody()) {
52545255
// Body returns control to the outer op, passing through results.
5255-
regions.emplace_back(op.getResults());
5256+
regions.emplace_back(op.getOperation(), op.getResults());
52565257
return;
52575258
}
5258-
assert(point == op.getCalculation());
5259+
assert(region == &op.getCalculation());
52595260
// Calculation branches to the body.
52605261
regions.emplace_back(&op.getBody());
52615262
}
@@ -5279,7 +5280,7 @@ void DtypeCalculateOp::getSuccessorRegions(
52795280
//===----------------------------------------------------------------------===//
52805281

52815282
MutableOperandRange ShapeCalculateYieldShapesOp::getMutableSuccessorOperands(
5282-
RegionBranchPoint point) {
5283+
RegionSuccessor successor) {
52835284
// The shape operands don't get forwarded to the body.
52845285
// MutableOperandRange always has an owning operation, even if empty, so
52855286
// create a 0-length range.
@@ -5770,7 +5771,7 @@ LogicalResult AtenKthvalueOp::verify() {
57705771
//===----------------------------------------------------------------------===//
57715772

57725773
MutableOperandRange DtypeCalculateYieldDtypesOp::getMutableSuccessorOperands(
5773-
RegionBranchPoint point) {
5774+
RegionSuccessor successor) {
57745775
// The dtype operands don't get forwarded to the body.
57755776
// MutableOperandRange always has an owning operation, even if empty, so
57765777
// create a 0-length range.

0 commit comments

Comments
 (0)