@@ -244,88 +244,41 @@ class SparseIterateOpConverter : public OneToNOpConversionPattern<IterateOp> {
244244 std::unique_ptr<SparseIterator> it =
245245 iterSpace.extractIterator (rewriter, loc);
246246
247- if (it->iteratableByFor ()) {
248- auto [lo, hi] = it->genForCond (rewriter, loc);
249- Value step = constantIndex (rewriter, loc, 1 );
250- SmallVector<Value> ivs;
251- for (ValueRange inits : adaptor.getInitArgs ())
252- llvm::append_range (ivs, inits);
253- scf::ForOp forOp = rewriter.create <scf::ForOp>(loc, lo, hi, step, ivs);
254-
255- Block *loopBody = op.getBody ();
256- OneToNTypeMapping bodyTypeMapping (loopBody->getArgumentTypes ());
257- if (failed (typeConverter->convertSignatureArgs (
258- loopBody->getArgumentTypes (), bodyTypeMapping)))
259- return failure ();
260- rewriter.applySignatureConversion (loopBody, bodyTypeMapping);
261-
262- rewriter.eraseBlock (forOp.getBody ());
263- Region &dstRegion = forOp.getRegion ();
264- rewriter.inlineRegionBefore (op.getRegion (), dstRegion, dstRegion.end ());
265-
266- auto yieldOp =
267- llvm::cast<sparse_tensor::YieldOp>(forOp.getBody ()->getTerminator ());
268-
269- rewriter.setInsertionPointToEnd (forOp.getBody ());
270- // replace sparse_tensor.yield with scf.yield.
271- rewriter.create <scf::YieldOp>(loc, yieldOp.getResults ());
272- rewriter.eraseOp (yieldOp);
273-
274- const OneToNTypeMapping &resultMapping = adaptor.getResultMapping ();
275- rewriter.replaceOp (op, forOp.getResults (), resultMapping);
276- } else {
277- SmallVector<Value> ivs;
278- // TODO: put iterator at the end of argument list to be consistent with
279- // coiterate operation.
280- llvm::append_range (ivs, it->getCursor ());
281- for (ValueRange inits : adaptor.getInitArgs ())
282- llvm::append_range (ivs, inits);
283-
284- assert (llvm::all_of (ivs, [](Value v) { return v != nullptr ; }));
285-
286- TypeRange types = ValueRange (ivs).getTypes ();
287- auto whileOp = rewriter.create <scf::WhileOp>(loc, types, ivs);
288- SmallVector<Location> l (types.size (), op.getIterator ().getLoc ());
289-
290- // Generates loop conditions.
291- Block *before = rewriter.createBlock (&whileOp.getBefore (), {}, types, l);
292- rewriter.setInsertionPointToStart (before);
293- ValueRange bArgs = before->getArguments ();
294- auto [whileCond, remArgs] = it->genWhileCond (rewriter, loc, bArgs);
295- assert (remArgs.size () == adaptor.getInitArgs ().size ());
296- rewriter.create <scf::ConditionOp>(loc, whileCond, before->getArguments ());
297-
298- // Generates loop body.
299- Block *loopBody = op.getBody ();
300- OneToNTypeMapping bodyTypeMapping (loopBody->getArgumentTypes ());
301- if (failed (typeConverter->convertSignatureArgs (
302- loopBody->getArgumentTypes (), bodyTypeMapping)))
303- return failure ();
304- rewriter.applySignatureConversion (loopBody, bodyTypeMapping);
305-
306- Region &dstRegion = whileOp.getAfter ();
307- // TODO: handle uses of coordinate!
308- rewriter.inlineRegionBefore (op.getRegion (), dstRegion, dstRegion.end ());
309- ValueRange aArgs = whileOp.getAfterArguments ();
310- auto yieldOp = llvm::cast<sparse_tensor::YieldOp>(
311- whileOp.getAfterBody ()->getTerminator ());
312-
313- rewriter.setInsertionPointToEnd (whileOp.getAfterBody ());
247+ SmallVector<Value> ivs;
248+ for (ValueRange inits : adaptor.getInitArgs ())
249+ llvm::append_range (ivs, inits);
250+
251+ // Type conversion on iterate op block.
252+ OneToNTypeMapping blockTypeMapping (op.getBody ()->getArgumentTypes ());
253+ if (failed (typeConverter->convertSignatureArgs (
254+ op.getBody ()->getArgumentTypes (), blockTypeMapping)))
255+ return rewriter.notifyMatchFailure (
256+ op, " failed to convert iterate region argurment types" );
257+ rewriter.applySignatureConversion (op.getBody (), blockTypeMapping);
258+
259+ Block *block = op.getBody ();
260+ ValueRange ret = genLoopWithIterator (
261+ rewriter, loc, it.get (), ivs, /* iterFirst=*/ true ,
262+ [block](PatternRewriter &rewriter, Location loc, Region &loopBody,
263+ SparseIterator *it, ValueRange reduc) -> SmallVector<Value> {
264+ SmallVector<Value> blockArgs (it->getCursor ());
265+ // TODO: Also appends coordinates if used.
266+ // blockArgs.push_back(it->deref(rewriter, loc));
267+ llvm::append_range (blockArgs, reduc);
268+
269+ Block *dstBlock = &loopBody.getBlocks ().front ();
270+ rewriter.inlineBlockBefore (block, dstBlock, dstBlock->end (),
271+ blockArgs);
272+ auto yield = llvm::cast<sparse_tensor::YieldOp>(dstBlock->back ());
273+ // We can not use ValueRange as the operation holding the values will
274+ // be destoryed.
275+ SmallVector<Value> result (yield.getResults ());
276+ rewriter.eraseOp (yield);
277+ return result;
278+ });
314279
315- aArgs = it->linkNewScope (aArgs);
316- ValueRange nx = it->forward (rewriter, loc);
317- SmallVector<Value> yields;
318- llvm::append_range (yields, nx);
319- llvm::append_range (yields, yieldOp.getResults ());
320-
321- // replace sparse_tensor.yield with scf.yield.
322- rewriter.eraseOp (yieldOp);
323- rewriter.create <scf::YieldOp>(loc, yields);
324- const OneToNTypeMapping &resultMapping = adaptor.getResultMapping ();
325- rewriter.replaceOp (
326- op, whileOp.getResults ().drop_front (it->getCursor ().size ()),
327- resultMapping);
328- }
280+ const OneToNTypeMapping &resultMapping = adaptor.getResultMapping ();
281+ rewriter.replaceOp (op, ret, resultMapping);
329282 return success ();
330283 }
331284};
@@ -366,9 +319,10 @@ class SparseCoIterateOpConverter
366319 Block *block = ®ion.getBlocks ().front ();
367320 OneToNTypeMapping blockTypeMapping (block->getArgumentTypes ());
368321 if (failed (typeConverter->convertSignatureArgs (block->getArgumentTypes (),
369- blockTypeMapping)))
322+ blockTypeMapping))) {
370323 return rewriter.notifyMatchFailure (
371324 op, " failed to convert coiterate region argurment types" );
325+ }
372326
373327 rewriter.applySignatureConversion (block, blockTypeMapping);
374328 }
0 commit comments