@@ -1391,24 +1391,28 @@ Stmt LowererImplImperative::lowerForallPosition(Forall forall, Iterator iterator
13911391 endBound = endBounds[1 ];
13921392 }
13931393
1394- LoopKind kind = LoopKind::Serial;
1395- if (forall.getParallelUnit () == ParallelUnit::CPUVector && !ignoreVectorize) {
1396- kind = LoopKind::Vectorized;
1397- }
1398- else if (forall.getParallelUnit () != ParallelUnit::NotParallel
1399- && forall.getOutputRaceStrategy () != OutputRaceStrategy::ParallelReduction && !ignoreVectorize) {
1400- kind = LoopKind::Runtime;
1394+ Stmt loop = Block::make (strideGuard, declareCoordinate, boundsGuard, body);
1395+ if (iterator.isBranchless () && iterator.isCompact () &&
1396+ (iterator.getParent ().isRoot () || iterator.getParent ().isUnique ())) {
1397+ loop = Block::make (VarDecl::make (iterator.getPosVar (), startBound), loop);
1398+ } else {
1399+ LoopKind kind = LoopKind::Serial;
1400+ if (forall.getParallelUnit () == ParallelUnit::CPUVector && !ignoreVectorize) {
1401+ kind = LoopKind::Vectorized;
1402+ }
1403+ else if (forall.getParallelUnit () != ParallelUnit::NotParallel &&
1404+ forall.getOutputRaceStrategy () != OutputRaceStrategy::ParallelReduction &&
1405+ !ignoreVectorize) {
1406+ kind = LoopKind::Runtime;
1407+ }
1408+
1409+ loop = For::make (iterator.getPosVar (), startBound, endBound, 1 , loop, kind,
1410+ ignoreVectorize ? ParallelUnit::NotParallel : forall.getParallelUnit (),
1411+ ignoreVectorize ? 0 : forall.getUnrollFactor ());
14011412 }
14021413
14031414 // Loop with preamble and postamble
1404- return Block::blanks (
1405- boundsCompute,
1406- For::make (iterator.getPosVar (), startBound, endBound, 1 ,
1407- Block::make (strideGuard, declareCoordinate, boundsGuard, body),
1408- kind,
1409- ignoreVectorize ? ParallelUnit::NotParallel : forall.getParallelUnit (), ignoreVectorize ? 0 : forall.getUnrollFactor ()),
1410- posAppend);
1411-
1415+ return Block::blanks (boundsCompute, loop, posAppend);
14121416}
14131417
14141418Stmt LowererImplImperative::lowerForallFusedPosition (Forall forall, Iterator iterator,
0 commit comments