@@ -211,25 +211,14 @@ static omp::ReductionDeclareOp createDecl(PatternRewriter &builder,
211211 return decl;
212212}
213213
214- // / Returns an LLVM pointer type with the given element type, or an opaque
215- // / pointer if 'useOpaquePointers' is true.
216- static LLVM::LLVMPointerType getPointerType (Type elementType,
217- bool useOpaquePointers) {
218- if (useOpaquePointers)
219- return LLVM::LLVMPointerType::get (elementType.getContext ());
220- return LLVM::LLVMPointerType::get (elementType);
221- }
222-
223214// / Adds an atomic reduction combiner to the given OpenMP reduction declaration
224215// / using llvm.atomicrmw of the given kind.
225216static omp::ReductionDeclareOp addAtomicRMW (OpBuilder &builder,
226217 LLVM::AtomicBinOp atomicKind,
227218 omp::ReductionDeclareOp decl,
228- scf::ReduceOp reduce,
229- bool useOpaquePointers) {
219+ scf::ReduceOp reduce) {
230220 OpBuilder::InsertionGuard guard (builder);
231- Type type = reduce.getOperand ().getType ();
232- Type ptrType = getPointerType (type, useOpaquePointers);
221+ auto ptrType = LLVM::LLVMPointerType::get (builder.getContext ());
233222 Location reduceOperandLoc = reduce.getOperand ().getLoc ();
234223 builder.createBlock (&decl.getAtomicReductionRegion (),
235224 decl.getAtomicReductionRegion ().end (), {ptrType, ptrType},
@@ -250,8 +239,7 @@ static omp::ReductionDeclareOp addAtomicRMW(OpBuilder &builder,
250239// / the neutral value, necessary for the OpenMP declaration. If the reduction
251240// / cannot be recognized, returns null.
252241static omp::ReductionDeclareOp declareReduction (PatternRewriter &builder,
253- scf::ReduceOp reduce,
254- bool useOpaquePointers) {
242+ scf::ReduceOp reduce) {
255243 Operation *container = SymbolTable::getNearestSymbolTable (reduce);
256244 SymbolTable symbolTable (container);
257245
@@ -272,34 +260,29 @@ static omp::ReductionDeclareOp declareReduction(PatternRewriter &builder,
272260 if (matchSimpleReduction<arith::AddFOp, LLVM::FAddOp>(reduction)) {
273261 omp::ReductionDeclareOp decl = createDecl (builder, symbolTable, reduce,
274262 builder.getFloatAttr (type, 0.0 ));
275- return addAtomicRMW (builder, LLVM::AtomicBinOp::fadd, decl, reduce,
276- useOpaquePointers);
263+ return addAtomicRMW (builder, LLVM::AtomicBinOp::fadd, decl, reduce);
277264 }
278265 if (matchSimpleReduction<arith::AddIOp, LLVM::AddOp>(reduction)) {
279266 omp::ReductionDeclareOp decl = createDecl (builder, symbolTable, reduce,
280267 builder.getIntegerAttr (type, 0 ));
281- return addAtomicRMW (builder, LLVM::AtomicBinOp::add, decl, reduce,
282- useOpaquePointers);
268+ return addAtomicRMW (builder, LLVM::AtomicBinOp::add, decl, reduce);
283269 }
284270 if (matchSimpleReduction<arith::OrIOp, LLVM::OrOp>(reduction)) {
285271 omp::ReductionDeclareOp decl = createDecl (builder, symbolTable, reduce,
286272 builder.getIntegerAttr (type, 0 ));
287- return addAtomicRMW (builder, LLVM::AtomicBinOp::_or, decl, reduce,
288- useOpaquePointers);
273+ return addAtomicRMW (builder, LLVM::AtomicBinOp::_or, decl, reduce);
289274 }
290275 if (matchSimpleReduction<arith::XOrIOp, LLVM::XOrOp>(reduction)) {
291276 omp::ReductionDeclareOp decl = createDecl (builder, symbolTable, reduce,
292277 builder.getIntegerAttr (type, 0 ));
293- return addAtomicRMW (builder, LLVM::AtomicBinOp::_xor, decl, reduce,
294- useOpaquePointers);
278+ return addAtomicRMW (builder, LLVM::AtomicBinOp::_xor, decl, reduce);
295279 }
296280 if (matchSimpleReduction<arith::AndIOp, LLVM::AndOp>(reduction)) {
297281 omp::ReductionDeclareOp decl = createDecl (
298282 builder, symbolTable, reduce,
299283 builder.getIntegerAttr (
300284 type, llvm::APInt::getAllOnes (type.getIntOrFloatBitWidth ())));
301- return addAtomicRMW (builder, LLVM::AtomicBinOp::_and, decl, reduce,
302- useOpaquePointers);
285+ return addAtomicRMW (builder, LLVM::AtomicBinOp::_and, decl, reduce);
303286 }
304287
305288 // Match simple binary reductions that cannot be expressed with atomicrmw.
@@ -335,7 +318,7 @@ static omp::ReductionDeclareOp declareReduction(PatternRewriter &builder,
335318 builder, symbolTable, reduce, minMaxValueForSignedInt (type, !isMin));
336319 return addAtomicRMW (builder,
337320 isMin ? LLVM::AtomicBinOp::min : LLVM::AtomicBinOp::max,
338- decl, reduce, useOpaquePointers );
321+ decl, reduce);
339322 }
340323 if (matchSelectReduction<arith::CmpIOp, arith::SelectOp>(
341324 reduction, {arith::CmpIPredicate::ult, arith::CmpIPredicate::ule},
@@ -347,7 +330,7 @@ static omp::ReductionDeclareOp declareReduction(PatternRewriter &builder,
347330 builder, symbolTable, reduce, minMaxValueForUnsignedInt (type, !isMin));
348331 return addAtomicRMW (
349332 builder, isMin ? LLVM::AtomicBinOp::umin : LLVM::AtomicBinOp::umax,
350- decl, reduce, useOpaquePointers );
333+ decl, reduce);
351334 }
352335
353336 return nullptr ;
@@ -357,11 +340,8 @@ namespace {
357340
358341struct ParallelOpLowering : public OpRewritePattern <scf::ParallelOp> {
359342
360- bool useOpaquePointers;
361-
362- ParallelOpLowering (MLIRContext *context, bool useOpaquePointers)
363- : OpRewritePattern<scf::ParallelOp>(context),
364- useOpaquePointers (useOpaquePointers) {}
343+ ParallelOpLowering (MLIRContext *context)
344+ : OpRewritePattern<scf::ParallelOp>(context) {}
365345
366346 LogicalResult matchAndRewrite (scf::ParallelOp parallelOp,
367347 PatternRewriter &rewriter) const override {
@@ -370,8 +350,7 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
370350 // declaration and use it instead of redeclaring.
371351 SmallVector<Attribute> reductionDeclSymbols;
372352 for (auto reduce : parallelOp.getOps <scf::ReduceOp>()) {
373- omp::ReductionDeclareOp decl =
374- declareReduction (rewriter, reduce, useOpaquePointers);
353+ omp::ReductionDeclareOp decl = declareReduction (rewriter, reduce);
375354 if (!decl)
376355 return failure ();
377356 reductionDeclSymbols.push_back (
@@ -385,14 +364,14 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
385364 loc, rewriter.getIntegerType (64 ), rewriter.getI64IntegerAttr (1 ));
386365 SmallVector<Value> reductionVariables;
387366 reductionVariables.reserve (parallelOp.getNumReductions ());
367+ auto ptrType = LLVM::LLVMPointerType::get (parallelOp.getContext ());
388368 for (Value init : parallelOp.getInitVals ()) {
389369 assert ((LLVM::isCompatibleType (init.getType ()) ||
390370 isa<LLVM::PointerElementTypeInterface>(init.getType ())) &&
391371 " cannot create a reduction variable if the type is not an LLVM "
392372 " pointer element" );
393- Value storage = rewriter.create <LLVM::AllocaOp>(
394- loc, getPointerType (init.getType (), useOpaquePointers),
395- init.getType (), one, 0 );
373+ Value storage =
374+ rewriter.create <LLVM::AllocaOp>(loc, ptrType, init.getType (), one, 0 );
396375 rewriter.create <LLVM::StoreOp>(loc, init, storage);
397376 reductionVariables.push_back (storage);
398377 }
@@ -464,14 +443,14 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
464443};
465444
466445// / Applies the conversion patterns in the given function.
467- static LogicalResult applyPatterns (ModuleOp module , bool useOpaquePointers ) {
446+ static LogicalResult applyPatterns (ModuleOp module ) {
468447 ConversionTarget target (*module .getContext ());
469448 target.addIllegalOp <scf::ReduceOp, scf::ReduceReturnOp, scf::ParallelOp>();
470449 target.addLegalDialect <omp::OpenMPDialect, LLVM::LLVMDialect,
471450 memref::MemRefDialect>();
472451
473452 RewritePatternSet patterns (module .getContext ());
474- patterns.add <ParallelOpLowering>(module .getContext (), useOpaquePointers );
453+ patterns.add <ParallelOpLowering>(module .getContext ());
475454 FrozenRewritePatternSet frozen (std::move (patterns));
476455 return applyPartialConversion (module , target, frozen);
477456}
@@ -484,7 +463,7 @@ struct SCFToOpenMPPass
484463
485464 // / Pass entry point.
486465 void runOnOperation () override {
487- if (failed (applyPatterns (getOperation (), useOpaquePointers )))
466+ if (failed (applyPatterns (getOperation ())))
488467 signalPassFailure ();
489468 }
490469};
0 commit comments