|
14 | 14 | #include "../PassDetail.h" |
15 | 15 | #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" |
16 | 16 | #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" |
| 17 | +#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" |
17 | 18 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
18 | 19 | #include "mlir/Dialect/StandardOps/IR/Ops.h" |
19 | 20 | #include "mlir/IR/Attributes.h" |
@@ -1793,31 +1794,6 @@ struct AllocLikeOpLowering : public ConvertToLLVMPattern { |
1793 | 1794 | return rewriter.create<LLVM::SubOp>(loc, bumped, mod); |
1794 | 1795 | } |
1795 | 1796 |
|
1796 | | - // Creates a call to an allocation function with params and casts the |
1797 | | - // resulting void pointer to ptrType. |
1798 | | - Value createAllocCall(Location loc, StringRef name, Type ptrType, |
1799 | | - ArrayRef<Value> params, ModuleOp module, |
1800 | | - ConversionPatternRewriter &rewriter) const { |
1801 | | - SmallVector<Type, 2> paramTypes; |
1802 | | - auto allocFuncOp = module.lookupSymbol<LLVM::LLVMFuncOp>(name); |
1803 | | - if (!allocFuncOp) { |
1804 | | - for (Value param : params) |
1805 | | - paramTypes.push_back(param.getType()); |
1806 | | - auto allocFuncType = |
1807 | | - LLVM::LLVMFunctionType::get(getVoidPtrType(), paramTypes); |
1808 | | - OpBuilder::InsertionGuard guard(rewriter); |
1809 | | - rewriter.setInsertionPointToStart(module.getBody()); |
1810 | | - allocFuncOp = rewriter.create<LLVM::LLVMFuncOp>(rewriter.getUnknownLoc(), |
1811 | | - name, allocFuncType); |
1812 | | - } |
1813 | | - auto allocFuncSymbol = rewriter.getSymbolRefAttr(allocFuncOp); |
1814 | | - auto allocatedPtr = rewriter |
1815 | | - .create<LLVM::CallOp>(loc, getVoidPtrType(), |
1816 | | - allocFuncSymbol, params) |
1817 | | - .getResult(0); |
1818 | | - return rewriter.create<LLVM::BitcastOp>(loc, ptrType, allocatedPtr); |
1819 | | - } |
1820 | | - |
1821 | 1797 | /// Allocates the underlying buffer. Returns the allocated pointer and the |
1822 | 1798 | /// aligned pointer. |
1823 | 1799 | virtual std::tuple<Value, Value> |
@@ -1909,9 +1885,12 @@ struct AllocOpLowering : public AllocLikeOpLowering { |
1909 | 1885 | // Allocate the underlying buffer and store a pointer to it in the MemRef |
1910 | 1886 | // descriptor. |
1911 | 1887 | Type elementPtrType = this->getElementPtrType(memRefType); |
| 1888 | + auto allocFuncOp = LLVM::lookupOrCreateMallocFn( |
| 1889 | + allocOp->getParentOfType<ModuleOp>(), getIndexType()); |
| 1890 | + auto results = createLLVMCall(rewriter, loc, allocFuncOp, {sizeBytes}, |
| 1891 | + getVoidPtrType()); |
1912 | 1892 | Value allocatedPtr = |
1913 | | - createAllocCall(loc, "malloc", elementPtrType, {sizeBytes}, |
1914 | | - allocOp->getParentOfType<ModuleOp>(), rewriter); |
| 1893 | + rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]); |
1915 | 1894 |
|
1916 | 1895 | Value alignedPtr = allocatedPtr; |
1917 | 1896 | if (alignment) { |
@@ -1991,9 +1970,13 @@ struct AlignedAllocOpLowering : public AllocLikeOpLowering { |
1991 | 1970 | sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment); |
1992 | 1971 |
|
1993 | 1972 | Type elementPtrType = this->getElementPtrType(memRefType); |
1994 | | - Value allocatedPtr = createAllocCall( |
1995 | | - loc, "aligned_alloc", elementPtrType, {allocAlignment, sizeBytes}, |
1996 | | - allocOp->getParentOfType<ModuleOp>(), rewriter); |
| 1973 | + auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn( |
| 1974 | + allocOp->getParentOfType<ModuleOp>(), getIndexType()); |
| 1975 | + auto results = |
| 1976 | + createLLVMCall(rewriter, loc, allocFuncOp, {allocAlignment, sizeBytes}, |
| 1977 | + getVoidPtrType()); |
| 1978 | + Value allocatedPtr = |
| 1979 | + rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, results[0]); |
1997 | 1980 |
|
1998 | 1981 | return std::make_tuple(allocatedPtr, allocatedPtr); |
1999 | 1982 | } |
@@ -2056,31 +2039,17 @@ static LogicalResult copyUnrankedDescriptors(OpBuilder &builder, Location loc, |
2056 | 2039 |
|
2057 | 2040 | // Get frequently used types. |
2058 | 2041 | MLIRContext *context = builder.getContext(); |
2059 | | - auto voidType = LLVM::LLVMVoidType::get(context); |
2060 | 2042 | Type voidPtrType = LLVM::LLVMPointerType::get(IntegerType::get(context, 8)); |
2061 | 2043 | auto i1Type = IntegerType::get(context, 1); |
2062 | 2044 | Type indexType = typeConverter.getIndexType(); |
2063 | 2045 |
|
2064 | 2046 | // Find the malloc and free, or declare them if necessary. |
2065 | 2047 | auto module = builder.getInsertionPoint()->getParentOfType<ModuleOp>(); |
2066 | | - auto mallocFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("malloc"); |
2067 | | - if (!mallocFunc && toDynamic) { |
2068 | | - OpBuilder::InsertionGuard guard(builder); |
2069 | | - builder.setInsertionPointToStart(module.getBody()); |
2070 | | - mallocFunc = builder.create<LLVM::LLVMFuncOp>( |
2071 | | - builder.getUnknownLoc(), "malloc", |
2072 | | - LLVM::LLVMFunctionType::get(voidPtrType, llvm::makeArrayRef(indexType), |
2073 | | - /*isVarArg=*/false)); |
2074 | | - } |
2075 | | - auto freeFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("free"); |
2076 | | - if (!freeFunc && !toDynamic) { |
2077 | | - OpBuilder::InsertionGuard guard(builder); |
2078 | | - builder.setInsertionPointToStart(module.getBody()); |
2079 | | - freeFunc = builder.create<LLVM::LLVMFuncOp>( |
2080 | | - builder.getUnknownLoc(), "free", |
2081 | | - LLVM::LLVMFunctionType::get(voidType, llvm::makeArrayRef(voidPtrType), |
2082 | | - /*isVarArg=*/false)); |
2083 | | - } |
| 2048 | + LLVM::LLVMFuncOp freeFunc, mallocFunc; |
| 2049 | + if (toDynamic) |
| 2050 | + mallocFunc = LLVM::lookupOrCreateMallocFn(module, indexType); |
| 2051 | + if (!toDynamic) |
| 2052 | + freeFunc = LLVM::lookupOrCreateFreeFn(module); |
2084 | 2053 |
|
2085 | 2054 | // Initialize shared constants. |
2086 | 2055 | Value zero = |
@@ -2217,17 +2186,7 @@ struct DeallocOpLowering : public ConvertOpToLLVMPattern<DeallocOp> { |
2217 | 2186 | DeallocOp::Adaptor transformed(operands); |
2218 | 2187 |
|
2219 | 2188 | // Insert the `free` declaration if it is not already present. |
2220 | | - auto freeFunc = |
2221 | | - op->getParentOfType<ModuleOp>().lookupSymbol<LLVM::LLVMFuncOp>("free"); |
2222 | | - if (!freeFunc) { |
2223 | | - OpBuilder::InsertionGuard guard(rewriter); |
2224 | | - rewriter.setInsertionPointToStart( |
2225 | | - op->getParentOfType<ModuleOp>().getBody()); |
2226 | | - freeFunc = rewriter.create<LLVM::LLVMFuncOp>( |
2227 | | - rewriter.getUnknownLoc(), "free", |
2228 | | - LLVM::LLVMFunctionType::get(getVoidType(), getVoidPtrType())); |
2229 | | - } |
2230 | | - |
| 2189 | + auto freeFunc = LLVM::lookupOrCreateFreeFn(op->getParentOfType<ModuleOp>()); |
2231 | 2190 | MemRefDescriptor memref(transformed.memref()); |
2232 | 2191 | Value casted = rewriter.create<LLVM::BitcastOp>( |
2233 | 2192 | op.getLoc(), getVoidPtrType(), |
|
0 commit comments