1111// ===----------------------------------------------------------------------===//
1212
1313#include " LowerToMLIRHelpers.h"
14+ #include " mlir/Analysis/DataLayoutAnalysis.h"
1415#include " mlir/Conversion/AffineToStandard/AffineToStandard.h"
1516#include " mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
1617#include " mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
3233#include " mlir/IR/Operation.h"
3334#include " mlir/IR/Region.h"
3435#include " mlir/IR/TypeRange.h"
36+ #include " mlir/IR/Types.h"
3537#include " mlir/IR/Value.h"
3638#include " mlir/IR/ValueRange.h"
39+ #include " mlir/Interfaces/DataLayoutInterfaces.h"
3740#include " mlir/Pass/Pass.h"
3841#include " mlir/Pass/PassManager.h"
3942#include " mlir/Support/LogicalResult.h"
@@ -163,17 +166,17 @@ class CIRAllocaOpLowering : public mlir::OpConversionPattern<cir::AllocaOp> {
163166 matchAndRewrite (cir::AllocaOp op, OpAdaptor adaptor,
164167 mlir::ConversionPatternRewriter &rewriter) const override {
165168
166- mlir::Type mlirType =
167- convertTypeForMemory (*getTypeConverter (), adaptor. getAllocaType () );
169+ mlir::Type allocaType = adaptor. getAllocaType ();
170+ mlir::Type mlirType = convertTypeForMemory (*getTypeConverter (), allocaType );
168171
169172 // FIXME: Some types can not be converted yet (e.g. struct)
170173 if (!mlirType)
171174 return mlir::LogicalResult::failure ();
172175
173176 auto memreftype = mlir::dyn_cast<mlir::MemRefType>(mlirType);
174- if (memreftype && mlir::isa<cir::ArrayType>(adaptor. getAllocaType ())) {
175- // if the type is an array,
176- // we don't need to wrap with memref .
177+ if (memreftype && ( mlir::isa<cir::ArrayType>(allocaType) ||
178+ mlir::isa<cir::RecordType>(allocaType))) {
179+ // Arrays and structs are already memref. No need to wrap another one .
177180 } else {
178181 memreftype = mlir::MemRefType::get ({}, mlirType);
179182 }
@@ -1240,6 +1243,36 @@ class CIRPtrStrideOpLowering
12401243 }
12411244};
12421245
1246+ class CIRGetMemberOpLowering
1247+ : public mlir::OpConversionPattern<cir::GetMemberOp> {
1248+ public:
1249+ CIRGetMemberOpLowering (mlir::TypeConverter &converter, mlir::MLIRContext *ctx,
1250+ const mlir::DataLayout &layout)
1251+ : OpConversionPattern(converter, ctx), layout(layout) {}
1252+
1253+ mlir::LogicalResult
1254+ matchAndRewrite (cir::GetMemberOp op, OpAdaptor adaptor,
1255+ mlir::ConversionPatternRewriter &rewriter) const override {
1256+ auto baseAddr = op.getAddr ();
1257+ auto structType =
1258+ mlir::cast<cir::RecordType>(baseAddr.getType ().getPointee ());
1259+ uint64_t byteOffset = structType.getElementOffset (layout, op.getIndex ());
1260+
1261+ auto fieldType = op.getResult ().getType ();
1262+ auto resultType = mlir::cast<mlir::MemRefType>(
1263+ getTypeConverter ()->convertType (fieldType));
1264+
1265+ mlir::Value offsetValue =
1266+ rewriter.create <mlir::arith::ConstantIndexOp>(op.getLoc (), byteOffset);
1267+ rewriter.replaceOpWithNewOp <mlir::memref::ViewOp>(
1268+ op, resultType, adaptor.getAddr (), offsetValue, mlir::ValueRange{});
1269+ return mlir::success ();
1270+ }
1271+
1272+ private:
1273+ const mlir::DataLayout &layout;
1274+ };
1275+
12431276class CIRUnreachableOpLowering
12441277 : public mlir::OpConversionPattern<cir::UnreachableOp> {
12451278public:
@@ -1271,7 +1304,8 @@ class CIRTrapOpLowering : public mlir::OpConversionPattern<cir::TrapOp> {
12711304};
12721305
12731306void populateCIRToMLIRConversionPatterns (mlir::RewritePatternSet &patterns,
1274- mlir::TypeConverter &converter) {
1307+ mlir::TypeConverter &converter,
1308+ mlir::DataLayout layout) {
12751309 patterns.add <CIRReturnLowering, CIRBrOpLowering>(patterns.getContext ());
12761310
12771311 patterns
@@ -1292,16 +1326,20 @@ void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns,
12921326 CIRVectorExtractLowering, CIRVectorCmpOpLowering, CIRACosOpLowering,
12931327 CIRASinOpLowering, CIRUnreachableOpLowering, CIRTanOpLowering,
12941328 CIRTrapOpLowering>(converter, patterns.getContext ());
1329+
1330+ patterns.add <CIRGetMemberOpLowering>(converter, patterns.getContext (),
1331+ layout);
12951332}
12961333
1297- static mlir::TypeConverter prepareTypeConverter () {
1334+ static mlir::TypeConverter prepareTypeConverter (mlir::DataLayout layout ) {
12981335 mlir::TypeConverter converter;
12991336 converter.addConversion ([&](cir::PointerType type) -> mlir::Type {
1300- auto ty = convertTypeForMemory (converter, type.getPointee ());
1337+ auto pointee = type.getPointee ();
1338+ auto ty = convertTypeForMemory (converter, pointee);
13011339 // FIXME: The pointee type might not be converted (e.g. struct)
13021340 if (!ty)
13031341 return nullptr ;
1304- if (isa<cir::ArrayType>(type. getPointee () ))
1342+ if (isa<cir::ArrayType>(pointee) || isa<cir::RecordType>(pointee ))
13051343 return ty;
13061344 return mlir::MemRefType::get ({}, ty);
13071345 });
@@ -1353,6 +1391,13 @@ static mlir::TypeConverter prepareTypeConverter() {
13531391 return nullptr ;
13541392 return mlir::MemRefType::get (shape, elementType);
13551393 });
1394+ converter.addConversion ([&](cir::RecordType type) -> mlir::Type {
1395+ // Reinterpret structs as raw bytes. Don't use tuples as they can't be put
1396+ // in memref.
1397+ auto size = type.getTypeSize (layout, {});
1398+ auto i8 = mlir::IntegerType::get (type.getContext (), /* width=*/ 8 );
1399+ return mlir::MemRefType::get (size.getFixedValue (), i8 );
1400+ });
13561401 converter.addConversion ([&](cir::VectorType type) -> mlir::Type {
13571402 auto ty = converter.convertType (type.getEltType ());
13581403 return mlir::VectorType::get (type.getSize (), ty);
@@ -1363,13 +1408,15 @@ static mlir::TypeConverter prepareTypeConverter() {
13631408
13641409void ConvertCIRToMLIRPass::runOnOperation () {
13651410 auto module = getOperation ();
1411+ mlir::DataLayoutAnalysis layoutAnalysis (module );
1412+ const mlir::DataLayout &layout = layoutAnalysis.getAtOrAbove (module );
13661413
1367- auto converter = prepareTypeConverter ();
1414+ auto converter = prepareTypeConverter (layout );
13681415
13691416 mlir::RewritePatternSet patterns (&getContext ());
13701417
13711418 populateCIRLoopToSCFConversionPatterns (patterns, converter);
1372- populateCIRToMLIRConversionPatterns (patterns, converter);
1419+ populateCIRToMLIRConversionPatterns (patterns, converter, layout );
13731420
13741421 mlir::ConversionTarget target (getContext ());
13751422 target.addLegalOp <mlir::ModuleOp>();
0 commit comments