1111// ===----------------------------------------------------------------------===//
1212
1313#include " LowerToMLIRHelpers.h"
14+ #include " mlir/Analysis/DataLayoutAnalysis.h"
1415#include " mlir/Conversion/AffineToStandard/AffineToStandard.h"
1516#include " mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
1617#include " mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
3233#include " mlir/IR/Operation.h"
3334#include " mlir/IR/Region.h"
3435#include " mlir/IR/TypeRange.h"
36+ #include " mlir/IR/Types.h"
3537#include " mlir/IR/Value.h"
3638#include " mlir/IR/ValueRange.h"
39+ #include " mlir/Interfaces/DataLayoutInterfaces.h"
3740#include " mlir/Pass/Pass.h"
3841#include " mlir/Pass/PassManager.h"
3942#include " mlir/Support/LogicalResult.h"
@@ -163,17 +166,17 @@ class CIRAllocaOpLowering : public mlir::OpConversionPattern<cir::AllocaOp> {
163166 matchAndRewrite (cir::AllocaOp op, OpAdaptor adaptor,
164167 mlir::ConversionPatternRewriter &rewriter) const override {
165168
166- mlir::Type mlirType =
167- convertTypeForMemory (*getTypeConverter (), adaptor. getAllocaType () );
169+ mlir::Type allocaType = adaptor. getAllocaType ();
170+ mlir::Type mlirType = convertTypeForMemory (*getTypeConverter (), allocaType );
168171
169172 // FIXME: Some types can not be converted yet (e.g. struct)
170173 if (!mlirType)
171174 return mlir::LogicalResult::failure ();
172175
173176 auto memreftype = mlir::dyn_cast<mlir::MemRefType>(mlirType);
174- if (memreftype && mlir::isa<cir::ArrayType>(adaptor. getAllocaType ())) {
175- // if the type is an array,
176- // we don't need to wrap with memref .
177+ if (memreftype && ( mlir::isa<cir::ArrayType>(allocaType) ||
178+ mlir::isa<cir::RecordType>(allocaType))) {
179+ // Arrays and structs are already memref. No need to wrap another one .
177180 } else {
178181 memreftype = mlir::MemRefType::get ({}, mlirType);
179182 }
@@ -1240,6 +1243,36 @@ class CIRPtrStrideOpLowering
12401243 }
12411244};
12421245
1246+ class CIRGetMemberOpLowering
1247+ : public mlir::OpConversionPattern<cir::GetMemberOp> {
1248+ public:
1249+ CIRGetMemberOpLowering (mlir::TypeConverter &converter, mlir::MLIRContext *ctx,
1250+ const mlir::DataLayout &layout)
1251+ : OpConversionPattern(converter, ctx), layout(layout) {}
1252+
1253+ mlir::LogicalResult
1254+ matchAndRewrite (cir::GetMemberOp op, OpAdaptor adaptor,
1255+ mlir::ConversionPatternRewriter &rewriter) const override {
1256+ auto baseAddr = op.getAddr ();
1257+ auto structType =
1258+ mlir::cast<cir::RecordType>(baseAddr.getType ().getPointee ());
1259+ uint64_t byteOffset = structType.getElementOffset (layout, op.getIndex ());
1260+
1261+ auto fieldType = op.getResult ().getType ();
1262+ auto resultType = mlir::cast<mlir::MemRefType>(
1263+ getTypeConverter ()->convertType (fieldType));
1264+
1265+ mlir::Value offsetValue =
1266+ rewriter.create <mlir::arith::ConstantIndexOp>(op.getLoc (), byteOffset);
1267+ rewriter.replaceOpWithNewOp <mlir::memref::ViewOp>(
1268+ op, resultType, adaptor.getAddr (), offsetValue, mlir::ValueRange{});
1269+ return mlir::success ();
1270+ }
1271+
1272+ private:
1273+ const mlir::DataLayout &layout;
1274+ };
1275+
12431276class CIRUnreachableOpLowering
12441277 : public mlir::OpConversionPattern<cir::UnreachableOp> {
12451278public:
@@ -1255,7 +1288,8 @@ class CIRUnreachableOpLowering
12551288};
12561289
12571290void populateCIRToMLIRConversionPatterns (mlir::RewritePatternSet &patterns,
1258- mlir::TypeConverter &converter) {
1291+ mlir::TypeConverter &converter,
1292+ mlir::DataLayout layout) {
12591293 patterns.add <CIRReturnLowering, CIRBrOpLowering>(patterns.getContext ());
12601294
12611295 patterns
@@ -1276,16 +1310,20 @@ void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns,
12761310 CIRVectorExtractLowering, CIRVectorCmpOpLowering, CIRACosOpLowering,
12771311 CIRASinOpLowering, CIRUnreachableOpLowering, CIRTanOpLowering>(
12781312 converter, patterns.getContext ());
1313+
1314+ patterns.add <CIRGetMemberOpLowering>(converter, patterns.getContext (),
1315+ layout);
12791316}
12801317
1281- static mlir::TypeConverter prepareTypeConverter () {
1318+ static mlir::TypeConverter prepareTypeConverter (mlir::DataLayout layout ) {
12821319 mlir::TypeConverter converter;
12831320 converter.addConversion ([&](cir::PointerType type) -> mlir::Type {
1284- auto ty = convertTypeForMemory (converter, type.getPointee ());
1321+ auto pointee = type.getPointee ();
1322+ auto ty = convertTypeForMemory (converter, pointee);
12851323 // FIXME: The pointee type might not be converted (e.g. struct)
12861324 if (!ty)
12871325 return nullptr ;
1288- if (isa<cir::ArrayType>(type. getPointee () ))
1326+ if (isa<cir::ArrayType>(pointee) || isa<cir::RecordType>(pointee ))
12891327 return ty;
12901328 return mlir::MemRefType::get ({}, ty);
12911329 });
@@ -1337,6 +1375,13 @@ static mlir::TypeConverter prepareTypeConverter() {
13371375 return nullptr ;
13381376 return mlir::MemRefType::get (shape, elementType);
13391377 });
1378+ converter.addConversion ([&](cir::RecordType type) -> mlir::Type {
1379+ // Reinterpret structs as raw bytes. Don't use tuples as they can't be put
1380+ // in memref.
1381+ auto size = type.getTypeSize (layout, {});
1382+ auto i8 = mlir::IntegerType::get (type.getContext (), /* width=*/ 8 );
1383+ return mlir::MemRefType::get (size.getFixedValue (), i8 );
1384+ });
13401385 converter.addConversion ([&](cir::VectorType type) -> mlir::Type {
13411386 auto ty = converter.convertType (type.getEltType ());
13421387 return mlir::VectorType::get (type.getSize (), ty);
@@ -1347,13 +1392,15 @@ static mlir::TypeConverter prepareTypeConverter() {
13471392
13481393void ConvertCIRToMLIRPass::runOnOperation () {
13491394 auto module = getOperation ();
1395+ mlir::DataLayoutAnalysis layoutAnalysis (module );
1396+ const mlir::DataLayout &layout = layoutAnalysis.getAtOrAbove (module );
13501397
1351- auto converter = prepareTypeConverter ();
1398+ auto converter = prepareTypeConverter (layout );
13521399
13531400 mlir::RewritePatternSet patterns (&getContext ());
13541401
13551402 populateCIRLoopToSCFConversionPatterns (patterns, converter);
1356- populateCIRToMLIRConversionPatterns (patterns, converter);
1403+ populateCIRToMLIRConversionPatterns (patterns, converter, layout );
13571404
13581405 mlir::ConversionTarget target (getContext ());
13591406 target.addLegalOp <mlir::ModuleOp>();
0 commit comments