From 85ad9984ebe8f398f525f3bcd257648037d1dd90 Mon Sep 17 00:00:00 2001 From: Demetrius Kanios Date: Sun, 28 Sep 2025 22:40:31 -0700 Subject: [PATCH 01/17] Prepare basic GlobalISel setup and implement CallLowering::lowerFormalArguments and CallLowering::lowerReturn --- llvm/lib/Target/WebAssembly/CMakeLists.txt | 4 + .../GISel/WebAssemblyCallLowering.cpp | 687 ++++++++++++++++++ .../GISel/WebAssemblyCallLowering.h | 43 ++ .../GISel/WebAssemblyInstructionSelector.cpp | 0 .../GISel/WebAssemblyInstructionSelector.h | 0 .../GISel/WebAssemblyLegalizerInfo.cpp | 23 + .../GISel/WebAssemblyLegalizerInfo.h | 29 + .../GISel/WebAssemblyRegisterBankInfo.cpp | 0 .../GISel/WebAssemblyRegisterBankInfo.h | 0 .../WebAssembly/WebAssemblySubtarget.cpp | 30 +- .../Target/WebAssembly/WebAssemblySubtarget.h | 14 + .../WebAssembly/WebAssemblyTargetMachine.cpp | 30 + 12 files changed, 859 insertions(+), 1 deletion(-) create mode 100644 llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.cpp create mode 100644 llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.h create mode 100644 llvm/lib/Target/WebAssembly/GISel/WebAssemblyInstructionSelector.cpp create mode 100644 llvm/lib/Target/WebAssembly/GISel/WebAssemblyInstructionSelector.h create mode 100644 llvm/lib/Target/WebAssembly/GISel/WebAssemblyLegalizerInfo.cpp create mode 100644 llvm/lib/Target/WebAssembly/GISel/WebAssemblyLegalizerInfo.h create mode 100644 llvm/lib/Target/WebAssembly/GISel/WebAssemblyRegisterBankInfo.cpp create mode 100644 llvm/lib/Target/WebAssembly/GISel/WebAssemblyRegisterBankInfo.h diff --git a/llvm/lib/Target/WebAssembly/CMakeLists.txt b/llvm/lib/Target/WebAssembly/CMakeLists.txt index 17df119d62709..ffb4ad182c81b 100644 --- a/llvm/lib/Target/WebAssembly/CMakeLists.txt +++ b/llvm/lib/Target/WebAssembly/CMakeLists.txt @@ -16,6 +16,10 @@ tablegen(LLVM WebAssemblyGenSubtargetInfo.inc -gen-subtarget) add_public_tablegen_target(WebAssemblyCommonTableGen) add_llvm_target(WebAssemblyCodeGen + 
GISel/WebAssemblyCallLowering.cpp + GISel/WebAssemblyInstructionSelector.cpp + GISel/WebAssemblyRegisterBankInfo.cpp + GISel/WebAssemblyLegalizerInfo.cpp WebAssemblyAddMissingPrototypes.cpp WebAssemblyArgumentMove.cpp WebAssemblyAsmPrinter.cpp diff --git a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.cpp b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.cpp new file mode 100644 index 0000000000000..5949d26a83840 --- /dev/null +++ b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.cpp @@ -0,0 +1,687 @@ +//===-- WebAssemblyCallLowering.cpp - Call lowering for GlobalISel -*- C++ -*-// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file implements the lowering of LLVM calls to machine code calls for +/// GlobalISel. 
+/// +//===----------------------------------------------------------------------===// + +#include "WebAssemblyCallLowering.h" +#include "MCTargetDesc/WebAssemblyMCTargetDesc.h" +#include "WebAssemblyISelLowering.h" +#include "WebAssemblyMachineFunctionInfo.h" +#include "WebAssemblySubtarget.h" +#include "WebAssemblyUtilities.h" +#include "llvm/CodeGen/Analysis.h" +#include "llvm/CodeGen/FunctionLoweringInfo.h" +#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h" +#include "llvm/CodeGen/LowLevelTypeUtils.h" +#include "llvm/CodeGenTypes/LowLevelType.h" +#include "llvm/IR/Argument.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/DebugLoc.h" + +#include "llvm/IR/DiagnosticInfo.h" +#include "llvm/IR/DiagnosticPrinter.h" +#include "llvm/Support/ErrorHandling.h" + +#define DEBUG_TYPE "wasm-call-lowering" + +using namespace llvm; + +// Several of the following methods are internal utilities defined in +// CodeGen/GlobalIsel/CallLowering.cpp +// TODO: Find a better solution? + +// Internal utility from CallLowering.cpp +static unsigned extendOpFromFlags(ISD::ArgFlagsTy Flags) { + if (Flags.isSExt()) + return TargetOpcode::G_SEXT; + if (Flags.isZExt()) + return TargetOpcode::G_ZEXT; + return TargetOpcode::G_ANYEXT; +} + +// Internal utility from CallLowering.cpp +/// Pack values \p SrcRegs to cover the vector type result \p DstRegs. +static MachineInstrBuilder +mergeVectorRegsToResultRegs(MachineIRBuilder &B, ArrayRef DstRegs, + ArrayRef SrcRegs) { + MachineRegisterInfo &MRI = *B.getMRI(); + LLT LLTy = MRI.getType(DstRegs[0]); + LLT PartLLT = MRI.getType(SrcRegs[0]); + + // Deal with v3s16 split into v2s16 + LLT LCMTy = getCoverTy(LLTy, PartLLT); + if (LCMTy == LLTy) { + // Common case where no padding is needed. + assert(DstRegs.size() == 1); + return B.buildConcatVectors(DstRegs[0], SrcRegs); + } + + // We need to create an unmerge to the result registers, which may require + // widening the original value. 
+ Register UnmergeSrcReg; + if (LCMTy != PartLLT) { + assert(DstRegs.size() == 1); + return B.buildDeleteTrailingVectorElements( + DstRegs[0], B.buildMergeLikeInstr(LCMTy, SrcRegs)); + } else { + // We don't need to widen anything if we're extracting a scalar which was + // promoted to a vector e.g. s8 -> v4s8 -> s8 + assert(SrcRegs.size() == 1); + UnmergeSrcReg = SrcRegs[0]; + } + + int NumDst = LCMTy.getSizeInBits() / LLTy.getSizeInBits(); + + SmallVector PadDstRegs(NumDst); + llvm::copy(DstRegs, PadDstRegs.begin()); + + // Create the excess dead defs for the unmerge. + for (int I = DstRegs.size(); I != NumDst; ++I) + PadDstRegs[I] = MRI.createGenericVirtualRegister(LLTy); + + if (PadDstRegs.size() == 1) + return B.buildDeleteTrailingVectorElements(DstRegs[0], UnmergeSrcReg); + return B.buildUnmerge(PadDstRegs, UnmergeSrcReg); +} + +// Internal utility from CallLowering.cpp +/// Create a sequence of instructions to combine pieces split into register +/// typed values to the original IR value. \p OrigRegs contains the destination +/// value registers of type \p LLTy, and \p Regs contains the legalized pieces +/// with type \p PartLLT. This is used for incoming values (physregs to vregs). +static void buildCopyFromRegs(MachineIRBuilder &B, ArrayRef OrigRegs, + ArrayRef Regs, LLT LLTy, LLT PartLLT, + const ISD::ArgFlagsTy Flags) { + MachineRegisterInfo &MRI = *B.getMRI(); + + if (PartLLT == LLTy) { + // We should have avoided introducing a new virtual register, and just + // directly assigned here. + assert(OrigRegs[0] == Regs[0]); + return; + } + + if (PartLLT.getSizeInBits() == LLTy.getSizeInBits() && OrigRegs.size() == 1 && + Regs.size() == 1) { + B.buildBitcast(OrigRegs[0], Regs[0]); + return; + } + + // A vector PartLLT needs extending to LLTy's element size. + // E.g. <2 x s64> = G_SEXT <2 x s32>. 
+ if (PartLLT.isVector() == LLTy.isVector() && + PartLLT.getScalarSizeInBits() > LLTy.getScalarSizeInBits() && + (!PartLLT.isVector() || + PartLLT.getElementCount() == LLTy.getElementCount()) && + OrigRegs.size() == 1 && Regs.size() == 1) { + Register SrcReg = Regs[0]; + + LLT LocTy = MRI.getType(SrcReg); + + if (Flags.isSExt()) { + SrcReg = B.buildAssertSExt(LocTy, SrcReg, LLTy.getScalarSizeInBits()) + .getReg(0); + } else if (Flags.isZExt()) { + SrcReg = B.buildAssertZExt(LocTy, SrcReg, LLTy.getScalarSizeInBits()) + .getReg(0); + } + + // Sometimes pointers are passed zero extended. + LLT OrigTy = MRI.getType(OrigRegs[0]); + if (OrigTy.isPointer()) { + LLT IntPtrTy = LLT::scalar(OrigTy.getSizeInBits()); + B.buildIntToPtr(OrigRegs[0], B.buildTrunc(IntPtrTy, SrcReg)); + return; + } + + B.buildTrunc(OrigRegs[0], SrcReg); + return; + } + + if (!LLTy.isVector() && !PartLLT.isVector()) { + assert(OrigRegs.size() == 1); + LLT OrigTy = MRI.getType(OrigRegs[0]); + + unsigned SrcSize = PartLLT.getSizeInBits().getFixedValue() * Regs.size(); + if (SrcSize == OrigTy.getSizeInBits()) + B.buildMergeValues(OrigRegs[0], Regs); + else { + auto Widened = B.buildMergeLikeInstr(LLT::scalar(SrcSize), Regs); + B.buildTrunc(OrigRegs[0], Widened); + } + + return; + } + + if (PartLLT.isVector()) { + assert(OrigRegs.size() == 1); + SmallVector CastRegs(Regs); + + // If PartLLT is a mismatched vector in both number of elements and element + // size, e.g. PartLLT == v2s64 and LLTy is v3s32, then first coerce it to + // have the same elt type, i.e. v4s32. + // TODO: Extend this coersion to element multiples other than just 2. 
+ if (TypeSize::isKnownGT(PartLLT.getSizeInBits(), LLTy.getSizeInBits()) && + PartLLT.getScalarSizeInBits() == LLTy.getScalarSizeInBits() * 2 && + Regs.size() == 1) { + LLT NewTy = PartLLT.changeElementType(LLTy.getElementType()) + .changeElementCount(PartLLT.getElementCount() * 2); + CastRegs[0] = B.buildBitcast(NewTy, Regs[0]).getReg(0); + PartLLT = NewTy; + } + + if (LLTy.getScalarType() == PartLLT.getElementType()) { + mergeVectorRegsToResultRegs(B, OrigRegs, CastRegs); + } else { + unsigned I = 0; + LLT GCDTy = getGCDType(LLTy, PartLLT); + + // We are both splitting a vector, and bitcasting its element types. Cast + // the source pieces into the appropriate number of pieces with the result + // element type. + for (Register SrcReg : CastRegs) + CastRegs[I++] = B.buildBitcast(GCDTy, SrcReg).getReg(0); + mergeVectorRegsToResultRegs(B, OrigRegs, CastRegs); + } + + return; + } + + assert(LLTy.isVector() && !PartLLT.isVector()); + + LLT DstEltTy = LLTy.getElementType(); + + // Pointer information was discarded. We'll need to coerce some register types + // to avoid violating type constraints. + LLT RealDstEltTy = MRI.getType(OrigRegs[0]).getElementType(); + + assert(DstEltTy.getSizeInBits() == RealDstEltTy.getSizeInBits()); + + if (DstEltTy == PartLLT) { + // Vector was trivially scalarized. + + if (RealDstEltTy.isPointer()) { + for (Register Reg : Regs) + MRI.setType(Reg, RealDstEltTy); + } + + B.buildBuildVector(OrigRegs[0], Regs); + } else if (DstEltTy.getSizeInBits() > PartLLT.getSizeInBits()) { + // Deal with vector with 64-bit elements decomposed to 32-bit + // registers. Need to create intermediate 64-bit elements. 
+ SmallVector EltMerges; + int PartsPerElt = + divideCeil(DstEltTy.getSizeInBits(), PartLLT.getSizeInBits()); + LLT ExtendedPartTy = LLT::scalar(PartLLT.getSizeInBits() * PartsPerElt); + + for (int I = 0, NumElts = LLTy.getNumElements(); I != NumElts; ++I) { + auto Merge = + B.buildMergeLikeInstr(ExtendedPartTy, Regs.take_front(PartsPerElt)); + if (ExtendedPartTy.getSizeInBits() > RealDstEltTy.getSizeInBits()) + Merge = B.buildTrunc(RealDstEltTy, Merge); + // Fix the type in case this is really a vector of pointers. + MRI.setType(Merge.getReg(0), RealDstEltTy); + EltMerges.push_back(Merge.getReg(0)); + Regs = Regs.drop_front(PartsPerElt); + } + + B.buildBuildVector(OrigRegs[0], EltMerges); + } else { + // Vector was split, and elements promoted to a wider type. + // FIXME: Should handle floating point promotions. + unsigned NumElts = LLTy.getNumElements(); + LLT BVType = LLT::fixed_vector(NumElts, PartLLT); + + Register BuildVec; + if (NumElts == Regs.size()) + BuildVec = B.buildBuildVector(BVType, Regs).getReg(0); + else { + // Vector elements are packed in the inputs. + // e.g. we have a <4 x s16> but 2 x s32 in regs. + assert(NumElts > Regs.size()); + LLT SrcEltTy = MRI.getType(Regs[0]); + + LLT OriginalEltTy = MRI.getType(OrigRegs[0]).getElementType(); + + // Input registers contain packed elements. + // Determine how many elements per reg. + assert((SrcEltTy.getSizeInBits() % OriginalEltTy.getSizeInBits()) == 0); + unsigned EltPerReg = + (SrcEltTy.getSizeInBits() / OriginalEltTy.getSizeInBits()); + + SmallVector BVRegs; + BVRegs.reserve(Regs.size() * EltPerReg); + for (Register R : Regs) { + auto Unmerge = B.buildUnmerge(OriginalEltTy, R); + for (unsigned K = 0; K < EltPerReg; ++K) + BVRegs.push_back(B.buildAnyExt(PartLLT, Unmerge.getReg(K)).getReg(0)); + } + + // We may have some more elements in BVRegs, e.g. if we have 2 s32 pieces + // for a <3 x s16> vector. We should have less than EltPerReg extra items. 
+ if (BVRegs.size() > NumElts) { + assert((BVRegs.size() - NumElts) < EltPerReg); + BVRegs.truncate(NumElts); + } + BuildVec = B.buildBuildVector(BVType, BVRegs).getReg(0); + } + B.buildTrunc(OrigRegs[0], BuildVec); + } +} + +// Internal utility from CallLowering.cpp +/// Create a sequence of instructions to expand the value in \p SrcReg (of type +/// \p SrcTy) to the types in \p DstRegs (of type \p PartTy). \p ExtendOp should +/// contain the type of scalar value extension if necessary. +/// +/// This is used for outgoing values (vregs to physregs) +static void buildCopyToRegs(MachineIRBuilder &B, ArrayRef DstRegs, + Register SrcReg, LLT SrcTy, LLT PartTy, + unsigned ExtendOp = TargetOpcode::G_ANYEXT) { + // We could just insert a regular copy, but this is unreachable at the moment. + assert(SrcTy != PartTy && "identical part types shouldn't reach here"); + + const TypeSize PartSize = PartTy.getSizeInBits(); + + if (PartTy.isVector() == SrcTy.isVector() && + PartTy.getScalarSizeInBits() > SrcTy.getScalarSizeInBits()) { + assert(DstRegs.size() == 1); + B.buildInstr(ExtendOp, {DstRegs[0]}, {SrcReg}); + return; + } + + if (SrcTy.isVector() && !PartTy.isVector() && + TypeSize::isKnownGT(PartSize, SrcTy.getElementType().getSizeInBits())) { + // Vector was scalarized, and the elements extended. + auto UnmergeToEltTy = B.buildUnmerge(SrcTy.getElementType(), SrcReg); + for (int i = 0, e = DstRegs.size(); i != e; ++i) + B.buildAnyExt(DstRegs[i], UnmergeToEltTy.getReg(i)); + return; + } + + if (SrcTy.isVector() && PartTy.isVector() && + PartTy.getSizeInBits() == SrcTy.getSizeInBits() && + ElementCount::isKnownLT(SrcTy.getElementCount(), + PartTy.getElementCount())) { + // A coercion like: v2f32 -> v4f32 or nxv2f32 -> nxv4f32 + Register DstReg = DstRegs.front(); + B.buildPadVectorWithUndefElements(DstReg, SrcReg); + return; + } + + LLT GCDTy = getGCDType(SrcTy, PartTy); + if (GCDTy == PartTy) { + // If this already evenly divisible, we can create a simple unmerge. 
+ B.buildUnmerge(DstRegs, SrcReg); + return; + } + + if (SrcTy.isVector() && !PartTy.isVector() && + SrcTy.getScalarSizeInBits() > PartTy.getSizeInBits()) { + LLT ExtTy = + LLT::vector(SrcTy.getElementCount(), + LLT::scalar(PartTy.getScalarSizeInBits() * DstRegs.size() / + SrcTy.getNumElements())); + auto Ext = B.buildAnyExt(ExtTy, SrcReg); + B.buildUnmerge(DstRegs, Ext); + return; + } + + MachineRegisterInfo &MRI = *B.getMRI(); + LLT DstTy = MRI.getType(DstRegs[0]); + LLT LCMTy = getCoverTy(SrcTy, PartTy); + + if (PartTy.isVector() && LCMTy == PartTy) { + assert(DstRegs.size() == 1); + B.buildPadVectorWithUndefElements(DstRegs[0], SrcReg); + return; + } + + const unsigned DstSize = DstTy.getSizeInBits(); + const unsigned SrcSize = SrcTy.getSizeInBits(); + unsigned CoveringSize = LCMTy.getSizeInBits(); + + Register UnmergeSrc = SrcReg; + + if (!LCMTy.isVector() && CoveringSize != SrcSize) { + // For scalars, it's common to be able to use a simple extension. + if (SrcTy.isScalar() && DstTy.isScalar()) { + CoveringSize = alignTo(SrcSize, DstSize); + LLT CoverTy = LLT::scalar(CoveringSize); + UnmergeSrc = B.buildInstr(ExtendOp, {CoverTy}, {SrcReg}).getReg(0); + } else { + // Widen to the common type. + // FIXME: This should respect the extend type + Register Undef = B.buildUndef(SrcTy).getReg(0); + SmallVector MergeParts(1, SrcReg); + for (unsigned Size = SrcSize; Size != CoveringSize; Size += SrcSize) + MergeParts.push_back(Undef); + UnmergeSrc = B.buildMergeLikeInstr(LCMTy, MergeParts).getReg(0); + } + } + + if (LCMTy.isVector() && CoveringSize != SrcSize) + UnmergeSrc = B.buildPadVectorWithUndefElements(LCMTy, SrcReg).getReg(0); + + B.buildUnmerge(DstRegs, UnmergeSrc); +} + +// Test whether the given calling convention is supported. +static bool callingConvSupported(CallingConv::ID CallConv) { + // We currently support the language-independent target-independent + // conventions. 
We don't yet have a way to annotate calls with properties like + // "cold", and we don't have any call-clobbered registers, so these are mostly + // all handled the same. + return CallConv == CallingConv::C || CallConv == CallingConv::Fast || + CallConv == CallingConv::Cold || + CallConv == CallingConv::PreserveMost || + CallConv == CallingConv::PreserveAll || + CallConv == CallingConv::CXX_FAST_TLS || + CallConv == CallingConv::WASM_EmscriptenInvoke || + CallConv == CallingConv::Swift; +} + +static void fail(MachineIRBuilder &MIRBuilder, const char *Msg) { + MachineFunction &MF = MIRBuilder.getMF(); + MIRBuilder.getContext().diagnose( + DiagnosticInfoUnsupported(MF.getFunction(), Msg, MIRBuilder.getDL())); +} + +WebAssemblyCallLowering::WebAssemblyCallLowering( + const WebAssemblyTargetLowering &TLI) + : CallLowering(&TLI) {} + +bool WebAssemblyCallLowering::canLowerReturn(MachineFunction &MF, + CallingConv::ID CallConv, + SmallVectorImpl &Outs, + bool IsVarArg) const { + return WebAssembly::canLowerReturn(Outs.size(), + &MF.getSubtarget()); +} + +bool WebAssemblyCallLowering::lowerReturn(MachineIRBuilder &MIRBuilder, + const Value *Val, + ArrayRef VRegs, + FunctionLoweringInfo &FLI, + Register SwiftErrorVReg) const { + auto MIB = MIRBuilder.buildInstrNoInsert(WebAssembly::RETURN); + + assert(((Val && !VRegs.empty()) || (!Val && VRegs.empty())) && + "Return value without a vreg"); + + if (Val && !FLI.CanLowerReturn) { + insertSRetStores(MIRBuilder, Val->getType(), VRegs, FLI.DemoteRegister); + } else if (!VRegs.empty()) { + MachineFunction &MF = MIRBuilder.getMF(); + const Function &F = MF.getFunction(); + MachineRegisterInfo &MRI = MF.getRegInfo(); + const WebAssemblyTargetLowering &TLI = *getTLI(); + auto &DL = F.getDataLayout(); + LLVMContext &Ctx = Val->getType()->getContext(); + + SmallVector SplitEVTs; + ComputeValueVTs(TLI, DL, Val->getType(), SplitEVTs); + assert(VRegs.size() == SplitEVTs.size() && + "For each split Type there should be exactly one 
VReg."); + + SmallVector SplitArgs; + CallingConv::ID CallConv = F.getCallingConv(); + + unsigned i = 0; + for (auto SplitEVT : SplitEVTs) { + Register CurVReg = VRegs[i]; + ArgInfo CurArgInfo = ArgInfo{CurVReg, SplitEVT.getTypeForEVT(Ctx), 0}; + setArgFlags(CurArgInfo, AttributeList::ReturnIndex, DL, F); + + splitToValueTypes(CurArgInfo, SplitArgs, DL, CallConv); + ++i; + } + + for (auto &Arg : SplitArgs) { + EVT OrigVT = TLI.getValueType(DL, Arg.Ty); + MVT NewVT = TLI.getRegisterTypeForCallingConv(Ctx, CallConv, OrigVT); + LLT OrigLLT = getLLTForType(*Arg.Ty, DL); + LLT NewLLT = getLLTForMVT(NewVT); + + // If we need to split the type over multiple regs, check it's a scenario + // we currently support. + unsigned NumParts = + TLI.getNumRegistersForCallingConv(Ctx, CallConv, OrigVT); + + ISD::ArgFlagsTy OrigFlags = Arg.Flags[0]; + Arg.Flags.clear(); + + for (unsigned Part = 0; Part < NumParts; ++Part) { + ISD::ArgFlagsTy Flags = OrigFlags; + if (Part == 0) { + Flags.setSplit(); + } else { + Flags.setOrigAlign(Align(1)); + if (Part == NumParts - 1) + Flags.setSplitEnd(); + } + + Arg.Flags.push_back(Flags); + } + + Arg.OrigRegs.assign(Arg.Regs.begin(), Arg.Regs.end()); + if (NumParts != 1 || OrigVT != NewVT) { + // If we can't directly assign the register, we need one or more + // intermediate values. + Arg.Regs.resize(NumParts); + + // For each split register, create and assign a vreg that will store + // the incoming component of the larger value. These will later be + // merged to form the final vreg. 
+ for (unsigned Part = 0; Part < NumParts; ++Part) { + Arg.Regs[Part] = MRI.createGenericVirtualRegister(NewLLT); + } + buildCopyToRegs(MIRBuilder, Arg.Regs, Arg.OrigRegs[0], OrigLLT, NewLLT, + extendOpFromFlags(Arg.Flags[0])); + } + + for (unsigned Part = 0; Part < NumParts; ++Part) { + MIB.addUse(Arg.Regs[Part]); + } + } + } + + if (SwiftErrorVReg) { + llvm_unreachable("WASM does not `supportSwiftError`, yet SwiftErrorVReg is " + "improperly valid."); + } + + MIRBuilder.insertInstr(MIB); + return true; +} + +static unsigned getWASMArgOpcode(MVT ArgType) { +#define MVT_CASE(type) \ + case MVT::type: \ + return WebAssembly::ARGUMENT_##type; + + switch (ArgType.SimpleTy) { + MVT_CASE(i32) + MVT_CASE(i64) + MVT_CASE(f32) + MVT_CASE(f64) + MVT_CASE(funcref) + MVT_CASE(externref) + MVT_CASE(exnref) + MVT_CASE(v16i8) + MVT_CASE(v8i16) + MVT_CASE(v4i32) + MVT_CASE(v2i64) + MVT_CASE(v4f32) + MVT_CASE(v2f64) + MVT_CASE(v8f16) + default: + break; + } + llvm_unreachable("Found unexpected type for WASM argument"); + +#undef MVT_CASE +} + +bool WebAssemblyCallLowering::lowerFormalArguments( + MachineIRBuilder &MIRBuilder, const Function &F, + ArrayRef> VRegs, FunctionLoweringInfo &FLI) const { + + MachineFunction &MF = MIRBuilder.getMF(); + MachineRegisterInfo &MRI = MF.getRegInfo(); + WebAssemblyFunctionInfo *MFI = MF.getInfo(); + const DataLayout &DL = F.getDataLayout(); + auto &TLI = *getTLI(); + LLVMContext &Ctx = MIRBuilder.getContext(); + const CallingConv::ID CallConv = F.getCallingConv(); + + if (!callingConvSupported(F.getCallingConv())) { + fail(MIRBuilder, "WebAssembly doesn't support non-C calling conventions"); + return false; + } + + // Set up the live-in for the incoming ARGUMENTS. 
+ MF.getRegInfo().addLiveIn(WebAssembly::ARGUMENTS); + + SmallVector SplitArgs; + + if (!FLI.CanLowerReturn) { + dbgs() << "grath\n"; + insertSRetIncomingArgument(F, SplitArgs, FLI.DemoteRegister, MRI, DL); + } + unsigned i = 0; + + bool HasSwiftErrorArg = false; + bool HasSwiftSelfArg = false; + for (const auto &Arg : F.args()) { + ArgInfo OrigArg{VRegs[i], Arg.getType(), i}; + setArgFlags(OrigArg, i + AttributeList::FirstArgIndex, DL, F); + + HasSwiftSelfArg |= Arg.hasSwiftSelfAttr(); + HasSwiftErrorArg |= Arg.hasSwiftErrorAttr(); + if (Arg.hasInAllocaAttr()) { + fail(MIRBuilder, "WebAssembly hasn't implemented inalloca arguments"); + return false; + } + if (Arg.hasNestAttr()) { + fail(MIRBuilder, "WebAssembly hasn't implemented nest arguments"); + return false; + } + splitToValueTypes(OrigArg, SplitArgs, DL, F.getCallingConv()); + ++i; + } + + unsigned FinalArgIdx = 0; + for (auto &Arg : SplitArgs) { + EVT OrigVT = TLI.getValueType(DL, Arg.Ty); + MVT NewVT = TLI.getRegisterTypeForCallingConv(Ctx, CallConv, OrigVT); + LLT OrigLLT = getLLTForType(*Arg.Ty, DL); + LLT NewLLT = getLLTForMVT(NewVT); + + // If we need to split the type over multiple regs, check it's a scenario + // we currently support. + unsigned NumParts = + TLI.getNumRegistersForCallingConv(Ctx, CallConv, OrigVT); + + ISD::ArgFlagsTy OrigFlags = Arg.Flags[0]; + Arg.Flags.clear(); + + for (unsigned Part = 0; Part < NumParts; ++Part) { + ISD::ArgFlagsTy Flags = OrigFlags; + if (Part == 0) { + Flags.setSplit(); + } else { + Flags.setOrigAlign(Align(1)); + if (Part == NumParts - 1) + Flags.setSplitEnd(); + } + + Arg.Flags.push_back(Flags); + } + + Arg.OrigRegs.assign(Arg.Regs.begin(), Arg.Regs.end()); + if (NumParts != 1 || OrigVT != NewVT) { + // If we can't directly assign the register, we need one or more + // intermediate values. + Arg.Regs.resize(NumParts); + + // For each split register, create and assign a vreg that will store + // the incoming component of the larger value. 
 These will later be + // merged to form the final vreg. + for (unsigned Part = 0; Part < NumParts; ++Part) { + Arg.Regs[Part] = MRI.createGenericVirtualRegister(NewLLT); + } + buildCopyFromRegs(MIRBuilder, Arg.OrigRegs, Arg.Regs, OrigLLT, NewLLT, + Arg.Flags[0]); + } + + for (unsigned Part = 0; Part < NumParts; ++Part) { + MIRBuilder.buildInstr(getWASMArgOpcode(NewVT)) + .addDef(Arg.Regs[Part]) + .addImm(FinalArgIdx); + MFI->addParam(NewVT); + ++FinalArgIdx; + } + } + + /**/ + + // For swiftcc, emit additional swiftself and swifterror arguments + // if there aren't any. These additional arguments are also added for callee + // signature. They are necessary to match callee and caller signature for + // indirect call. + auto PtrVT = TLI.getPointerTy(DL); + if (CallConv == CallingConv::Swift) { + if (!HasSwiftSelfArg) { + MFI->addParam(PtrVT); + } + if (!HasSwiftErrorArg) { + MFI->addParam(PtrVT); + } + } + + // Varargs are copied into a buffer allocated by the caller, and a pointer to + // the buffer is passed as an argument. + if (F.isVarArg()) { + auto PtrVT = TLI.getPointerTy(DL); + Register VarargVreg = MF.getRegInfo().createGenericVirtualRegister( + getLLTForType(*PointerType::get(Ctx, 0), DL)); + MFI->setVarargBufferVreg(VarargVreg); + + MIRBuilder.buildInstr(getWASMArgOpcode(PtrVT)) + .addDef(VarargVreg) + .addImm(FinalArgIdx); + + MFI->addParam(PtrVT); + ++FinalArgIdx; + } + + // Record the number and types of arguments and results. 
+ SmallVector Params; + SmallVector Results; + computeSignatureVTs(MF.getFunction().getFunctionType(), &MF.getFunction(), + MF.getFunction(), MF.getTarget(), Params, Results); + for (MVT VT : Results) + MFI->addResult(VT); + + // TODO: Use signatures in WebAssemblyMachineFunctionInfo too and unify + // the param logic here with ComputeSignatureVTs + assert(MFI->getParams().size() == Params.size() && + std::equal(MFI->getParams().begin(), MFI->getParams().end(), + Params.begin())); + return true; +} + +bool WebAssemblyCallLowering::lowerCall(MachineIRBuilder &MIRBuilder, + CallLoweringInfo &Info) const { + return false; +} diff --git a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.h b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.h new file mode 100644 index 0000000000000..d22f7cbd17eb3 --- /dev/null +++ b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.h @@ -0,0 +1,43 @@ +//===-- WebAssemblyCallLowering.h - Call lowering for GlobalISel -*- C++ -*-==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file describes how to lower LLVM calls to machine code calls. 
+/// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_LIB_TARGET_WEBASSEMBLY_GISEL_WEBASSEMBLYCALLLOWERING_H +#define LLVM_LIB_TARGET_WEBASSEMBLY_GISEL_WEBASSEMBLYCALLLOWERING_H + +#include "WebAssemblyISelLowering.h" +#include "llvm/CodeGen/GlobalISel/CallLowering.h" +#include "llvm/IR/CallingConv.h" + +namespace llvm { + +class WebAssemblyTargetLowering; + +class WebAssemblyCallLowering : public CallLowering { +public: + WebAssemblyCallLowering(const WebAssemblyTargetLowering &TLI); + + bool canLowerReturn(MachineFunction &MF, CallingConv::ID CallConv, + SmallVectorImpl &Outs, + bool IsVarArg) const override; + bool lowerReturn(MachineIRBuilder &MIRBuilder, const Value *Val, + ArrayRef VRegs, FunctionLoweringInfo &FLI, + Register SwiftErrorVReg) const override; + bool lowerFormalArguments(MachineIRBuilder &MIRBuilder, const Function &F, + ArrayRef> VRegs, + FunctionLoweringInfo &FLI) const override; + bool lowerCall(MachineIRBuilder &MIRBuilder, + CallLoweringInfo &Info) const override; +}; +} // namespace llvm + +#endif diff --git a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyInstructionSelector.cpp b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyInstructionSelector.cpp new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyInstructionSelector.h b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyInstructionSelector.h new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyLegalizerInfo.cpp b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyLegalizerInfo.cpp new file mode 100644 index 0000000000000..3acdabb5612cc --- /dev/null +++ b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyLegalizerInfo.cpp @@ -0,0 +1,23 @@ +//===- WebAssemblyLegalizerInfo.cpp ------------------------------*- C++ -*-==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// \file +/// This file implements the targeting of the MachineLegalizer class for +/// WebAssembly. +//===----------------------------------------------------------------------===// + +#include "WebAssemblyLegalizerInfo.h" + +#define DEBUG_TYPE "wasm-legalinfo" + +using namespace llvm; +using namespace LegalizeActions; + +WebAssemblyLegalizerInfo::WebAssemblyLegalizerInfo( + const WebAssemblySubtarget &ST) { + getLegacyLegalizerInfo().computeTables(); +} diff --git a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyLegalizerInfo.h b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyLegalizerInfo.h new file mode 100644 index 0000000000000..c02205fc7ae0d --- /dev/null +++ b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyLegalizerInfo.h @@ -0,0 +1,29 @@ +//===- WebAssemblyLegalizerInfo.h --------------------------------*- C++ -*-==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// \file +/// This file declares the targeting of the MachineLegalizer class for +/// WebAssembly. +//===----------------------------------------------------------------------===// + +#ifndef LLVM_LIB_TARGET_WEBASSEMBLY_GISEL_WEBASSEMBLYMACHINELEGALIZER_H +#define LLVM_LIB_TARGET_WEBASSEMBLY_GISEL_WEBASSEMBLYMACHINELEGALIZER_H + +#include "llvm/CodeGen/GlobalISel/LegalizerInfo.h" + +namespace llvm { + +class WebAssemblySubtarget; + +/// This class provides the information for the WebAssembly target legalizer +/// for GlobalISel. 
+class WebAssemblyLegalizerInfo : public LegalizerInfo { +public: + WebAssemblyLegalizerInfo(const WebAssemblySubtarget &ST); +}; +} // namespace llvm +#endif diff --git a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyRegisterBankInfo.cpp b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyRegisterBankInfo.cpp new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyRegisterBankInfo.h b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyRegisterBankInfo.h new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/llvm/lib/Target/WebAssembly/WebAssemblySubtarget.cpp b/llvm/lib/Target/WebAssembly/WebAssemblySubtarget.cpp index a3ce40f0297ec..3ea8b9f85819f 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblySubtarget.cpp +++ b/llvm/lib/Target/WebAssembly/WebAssemblySubtarget.cpp @@ -13,8 +13,12 @@ //===----------------------------------------------------------------------===// #include "WebAssemblySubtarget.h" +#include "GISel/WebAssemblyCallLowering.h" +#include "GISel/WebAssemblyLegalizerInfo.h" +#include "GISel/WebAssemblyRegisterBankInfo.h" #include "MCTargetDesc/WebAssemblyMCTargetDesc.h" #include "WebAssemblyInstrInfo.h" +#include "WebAssemblyTargetMachine.h" #include "llvm/MC/TargetRegistry.h" using namespace llvm; @@ -66,7 +70,15 @@ WebAssemblySubtarget::WebAssemblySubtarget(const Triple &TT, const TargetMachine &TM) : WebAssemblyGenSubtargetInfo(TT, CPU, /*TuneCPU*/ CPU, FS), TargetTriple(TT), InstrInfo(initializeSubtargetDependencies(CPU, FS)), - TLInfo(TM, *this) {} + TLInfo(TM, *this) { + CallLoweringInfo.reset(new WebAssemblyCallLowering(*getTargetLowering())); + Legalizer.reset(new WebAssemblyLegalizerInfo(*this)); + /*auto *RBI = new WebAssemblyRegisterBankInfo(*getRegisterInfo()); + RegBankInfo.reset(RBI); + + InstSelector.reset(createWebAssemblyInstructionSelector( + *static_cast(&TM), *this, *RBI));*/ +} bool WebAssemblySubtarget::enableAtomicExpand() const { // If atomics are disabled, atomic 
ops are lowered instead of expanded @@ -81,3 +93,19 @@ bool WebAssemblySubtarget::enableMachineScheduler() const { } bool WebAssemblySubtarget::useAA() const { return true; } + +const CallLowering *WebAssemblySubtarget::getCallLowering() const { + return CallLoweringInfo.get(); +} + +InstructionSelector *WebAssemblySubtarget::getInstructionSelector() const { + return InstSelector.get(); +} + +const LegalizerInfo *WebAssemblySubtarget::getLegalizerInfo() const { + return Legalizer.get(); +} + +const RegisterBankInfo *WebAssemblySubtarget::getRegBankInfo() const { + return RegBankInfo.get(); +} diff --git a/llvm/lib/Target/WebAssembly/WebAssemblySubtarget.h b/llvm/lib/Target/WebAssembly/WebAssemblySubtarget.h index 2f88bbba05d00..c195f995009b1 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblySubtarget.h +++ b/llvm/lib/Target/WebAssembly/WebAssemblySubtarget.h @@ -20,6 +20,10 @@ #include "WebAssemblyISelLowering.h" #include "WebAssemblyInstrInfo.h" #include "WebAssemblySelectionDAGInfo.h" +#include "llvm/CodeGen/GlobalISel/CallLowering.h" +#include "llvm/CodeGen/GlobalISel/InstructionSelector.h" +#include "llvm/CodeGen/GlobalISel/LegalizerInfo.h" +#include "llvm/CodeGen/RegisterBankInfo.h" #include "llvm/CodeGen/TargetSubtargetInfo.h" #include @@ -64,6 +68,11 @@ class WebAssemblySubtarget final : public WebAssemblyGenSubtargetInfo { WebAssemblySelectionDAGInfo TSInfo; WebAssemblyTargetLowering TLInfo; + std::unique_ptr CallLoweringInfo; + std::unique_ptr InstSelector; + std::unique_ptr Legalizer; + std::unique_ptr RegBankInfo; + WebAssemblySubtarget &initializeSubtargetDependencies(StringRef CPU, StringRef FS); @@ -118,6 +127,11 @@ class WebAssemblySubtarget final : public WebAssemblyGenSubtargetInfo { /// Parses features string setting specified subtarget options. Definition of /// function is auto generated by tblgen. 
void ParseSubtargetFeatures(StringRef CPU, StringRef TuneCPU, StringRef FS); + + const CallLowering *getCallLowering() const override; + InstructionSelector *getInstructionSelector() const override; + const LegalizerInfo *getLegalizerInfo() const override; + const RegisterBankInfo *getRegBankInfo() const override; }; } // end namespace llvm diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyTargetMachine.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyTargetMachine.cpp index 621640c12f695..66959f9c2ac43 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyTargetMachine.cpp +++ b/llvm/lib/Target/WebAssembly/WebAssemblyTargetMachine.cpp @@ -20,6 +20,10 @@ #include "WebAssemblyTargetObjectFile.h" #include "WebAssemblyTargetTransformInfo.h" #include "WebAssemblyUtilities.h" +#include "llvm/CodeGen/GlobalISel/IRTranslator.h" +#include "llvm/CodeGen/GlobalISel/InstructionSelect.h" +#include "llvm/CodeGen/GlobalISel/Legalizer.h" +#include "llvm/CodeGen/GlobalISel/RegBankSelect.h" #include "llvm/CodeGen/MIRParser/MIParser.h" #include "llvm/CodeGen/Passes.h" #include "llvm/CodeGen/RegAllocRegistry.h" @@ -92,6 +96,7 @@ LLVMInitializeWebAssemblyTarget() { // Register backend passes auto &PR = *PassRegistry::getPassRegistry(); + initializeGlobalISel(PR); initializeWebAssemblyAddMissingPrototypesPass(PR); initializeWebAssemblyLowerEmscriptenEHSjLjPass(PR); initializeLowerGlobalDtorsLegacyPassPass(PR); @@ -440,6 +445,11 @@ class WebAssemblyPassConfig final : public TargetPassConfig { // No reg alloc bool addRegAssignAndRewriteOptimized() override { return false; } + + bool addIRTranslator() override; + bool addLegalizeMachineIR() override; + bool addRegBankSelect() override; + bool addGlobalInstructionSelect() override; }; } // end anonymous namespace @@ -660,6 +670,26 @@ bool WebAssemblyPassConfig::addPreISel() { return false; } +bool WebAssemblyPassConfig::addIRTranslator() { + addPass(new IRTranslator()); + return false; +} + +bool WebAssemblyPassConfig::addLegalizeMachineIR() { 
+ addPass(new Legalizer()); + return false; +} + +bool WebAssemblyPassConfig::addRegBankSelect() { + addPass(new RegBankSelect()); + return false; +} + +bool WebAssemblyPassConfig::addGlobalInstructionSelect() { + addPass(new InstructionSelect(getOptLevel())); + return false; +} + yaml::MachineFunctionInfo * WebAssemblyTargetMachine::createDefaultFuncInfoYAML() const { return new yaml::WebAssemblyFunctionInfo(); From 20b8b4c1d2eb58e4a202ed4ef7c7d57edb760d64 Mon Sep 17 00:00:00 2001 From: Demetrius Kanios Date: Sun, 28 Sep 2025 22:41:02 -0700 Subject: [PATCH 02/17] Implement WebAssemblyCallLowering::lowerCall --- .../GISel/WebAssemblyCallLowering.cpp | 453 +++++++++++++++++- 1 file changed, 451 insertions(+), 2 deletions(-) diff --git a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.cpp b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.cpp index 5949d26a83840..8956932b403ef 100644 --- a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.cpp +++ b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.cpp @@ -14,14 +14,21 @@ #include "WebAssemblyCallLowering.h" #include "MCTargetDesc/WebAssemblyMCTargetDesc.h" +#include "Utils/WasmAddressSpaces.h" #include "WebAssemblyISelLowering.h" #include "WebAssemblyMachineFunctionInfo.h" #include "WebAssemblySubtarget.h" #include "WebAssemblyUtilities.h" +#include "llvm/Analysis/MemoryLocation.h" #include "llvm/CodeGen/Analysis.h" +#include "llvm/CodeGen/CallingConvLower.h" #include "llvm/CodeGen/FunctionLoweringInfo.h" +#include "llvm/CodeGen/GlobalISel/CallLowering.h" #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h" #include "llvm/CodeGen/LowLevelTypeUtils.h" +#include "llvm/CodeGen/MachineFrameInfo.h" +#include "llvm/CodeGen/MachineInstrBuilder.h" +#include "llvm/CodeGen/MachineMemOperand.h" #include "llvm/CodeGenTypes/LowLevelType.h" #include "llvm/IR/Argument.h" #include "llvm/IR/DataLayout.h" @@ -29,7 +36,10 @@ #include "llvm/IR/DiagnosticInfo.h" #include "llvm/IR/DiagnosticPrinter.h" 
+#include "llvm/MC/MCSymbolWasm.h" +#include "llvm/Support/Debug.h" #include "llvm/Support/ErrorHandling.h" +#include #define DEBUG_TYPE "wasm-call-lowering" @@ -555,7 +565,6 @@ bool WebAssemblyCallLowering::lowerFormalArguments( SmallVector SplitArgs; if (!FLI.CanLowerReturn) { - dbgs() << "grath\n"; insertSRetIncomingArgument(F, SplitArgs, FLI.DemoteRegister, MRI, DL); } unsigned i = 0; @@ -683,5 +692,445 @@ bool WebAssemblyCallLowering::lowerFormalArguments( bool WebAssemblyCallLowering::lowerCall(MachineIRBuilder &MIRBuilder, CallLoweringInfo &Info) const { - return false; + MachineFunction &MF = MIRBuilder.getMF(); + auto DL = MIRBuilder.getDataLayout(); + LLVMContext &Ctx = MIRBuilder.getContext(); + const WebAssemblyTargetLowering &TLI = *getTLI(); + MachineRegisterInfo &MRI = *MIRBuilder.getMRI(); + const WebAssemblySubtarget &Subtarget = MF.getSubtarget(); + + CallingConv::ID CallConv = Info.CallConv; + if (!callingConvSupported(CallConv)) { + fail(MIRBuilder, + "WebAssembly doesn't support language-specific or target-specific " + "calling conventions yet"); + return false; + } + + // TODO: investigate "PatchPoint" + /* + if (Info.IsPatchPoint) { + fail(MIRBuilder, "WebAssembly doesn't support patch point yet"); + return false; + } + */ + + if (Info.IsTailCall) { + Info.LoweredTailCall = true; + auto NoTail = [&](const char *Msg) { + if (Info.CB && Info.CB->isMustTailCall()) + fail(MIRBuilder, Msg); + Info.LoweredTailCall = false; + }; + + if (!Subtarget.hasTailCall()) + NoTail("WebAssembly 'tail-call' feature not enabled"); + + // Varargs calls cannot be tail calls because the buffer is on the stack + if (Info.IsVarArg) + NoTail("WebAssembly does not support varargs tail calls"); + + // Do not tail call unless caller and callee return types match + const Function &F = MF.getFunction(); + const TargetMachine &TM = TLI.getTargetMachine(); + Type *RetTy = F.getReturnType(); + SmallVector CallerRetTys; + SmallVector CalleeRetTys; + computeLegalValueVTs(F, TM, 
RetTy, CallerRetTys); + computeLegalValueVTs(F, TM, Info.OrigRet.Ty, CalleeRetTys); + bool TypesMatch = CallerRetTys.size() == CalleeRetTys.size() && + std::equal(CallerRetTys.begin(), CallerRetTys.end(), + CalleeRetTys.begin()); + if (!TypesMatch) + NoTail("WebAssembly tail call requires caller and callee return types to " + "match"); + + // If pointers to local stack values are passed, we cannot tail call + if (Info.CB) { + for (auto &Arg : Info.CB->args()) { + Value *Val = Arg.get(); + // Trace the value back through pointer operations + while (true) { + Value *Src = Val->stripPointerCastsAndAliases(); + if (auto *GEP = dyn_cast(Src)) + Src = GEP->getPointerOperand(); + if (Val == Src) + break; + Val = Src; + } + if (isa(Val)) { + NoTail( + "WebAssembly does not support tail calling with stack arguments"); + break; + } + } + } + } + + MachineInstrBuilder CallInst; + + bool IsIndirect = false; + Register IndirectIdx; + + if (Info.Callee.isReg()) { + LLT CalleeType = MRI.getType(Info.Callee.getReg()); + assert(CalleeType.isPointer() && "Trying to lower a call with a Callee other than a pointer???"); + + IsIndirect = true; + CallInst = MIRBuilder.buildInstrNoInsert(Info.LoweredTailCall ? WebAssembly::RET_CALL_INDIRECT : WebAssembly::CALL_INDIRECT); + + // Placeholder for the type index. 
+ // This gets replaced with the correct value in WebAssemblyMCInstLower.cpp + CallInst.addImm(0); + + MCSymbolWasm *Table; + if (CalleeType.getAddressSpace() == WebAssembly::WASM_ADDRESS_SPACE_DEFAULT) { + Table = WebAssembly::getOrCreateFunctionTableSymbol( + MF.getContext(), &Subtarget); + IndirectIdx = Info.Callee.getReg(); + + auto PtrSize = CalleeType.getSizeInBits(); + auto PtrIntLLT = LLT::scalar(PtrSize); + + IndirectIdx = MIRBuilder.buildPtrToInt(PtrIntLLT, IndirectIdx).getReg(0); + if (PtrSize > 32) { + IndirectIdx = MIRBuilder.buildTrunc(LLT::scalar(32), IndirectIdx).getReg(0); + } + } else if (CalleeType.getAddressSpace() == WebAssembly::WASM_ADDRESS_SPACE_FUNCREF) { + Table = WebAssembly::getOrCreateFuncrefCallTableSymbol( + MF.getContext(), &Subtarget); + + auto TableSetInstr = MIRBuilder.buildInstr(WebAssembly::TABLE_SET_FUNCREF); + TableSetInstr.addSym(Table); + TableSetInstr.addUse(Info.Callee.getReg()); + IndirectIdx = MIRBuilder.buildConstant(LLT::scalar(32), 0).getReg(0); + } else { + fail(MIRBuilder, "Invalid address space for indirect call"); + return false; + } + + if (Subtarget.hasCallIndirectOverlong()) { + CallInst.addSym(Table); + } else { + // For the MVP there is at most one table whose number is 0, but we can't + // write a table symbol or issue relocations. Instead we just ensure the + // table is live and write a zero. + Table->setNoStrip(); + CallInst.addImm(0); + } + } else { + CallInst = MIRBuilder.buildInstrNoInsert(Info.LoweredTailCall ? 
WebAssembly::RET_CALL : WebAssembly::CALL); + + if (Info.Callee.isGlobal()) { + CallInst.addGlobalAddress(Info.Callee.getGlobal()); + } else if (Info.Callee.isSymbol()) { + // TODO: figure out how to trigger/test this + CallInst.addSym(Info.Callee.getMCSymbol()); + } else { + llvm_unreachable("Trying to lower call with a callee other than reg, global, or a symbol."); + } + } + + + SmallVector SplitArgs; + + bool HasSwiftErrorArg = false; + bool HasSwiftSelfArg = false; + + for (const auto &Arg : Info.OrigArgs) { + HasSwiftSelfArg |= Arg.Flags[0].isSwiftSelf(); + HasSwiftErrorArg |= Arg.Flags[0].isSwiftError(); + if (Arg.Flags[0].isNest()) { + fail(MIRBuilder, "WebAssembly hasn't implemented nest arguments"); + return false; + } + if (Arg.Flags[0].isInAlloca()) { + fail(MIRBuilder, "WebAssembly hasn't implemented inalloca arguments"); + return false; + } + if (Arg.Flags[0].isInConsecutiveRegs()) { + fail(MIRBuilder, "WebAssembly hasn't implemented cons regs arguments"); + return false; + } + if (Arg.Flags[0].isInConsecutiveRegsLast()) { + fail(MIRBuilder, + "WebAssembly hasn't implemented cons regs last arguments"); + return false; + } + + if (Arg.Flags[0].isByVal() && Arg.Flags[0].getByValSize() != 0) { + MachineFrameInfo &MFI = MF.getFrameInfo(); + + unsigned MemSize = Arg.Flags[0].getByValSize(); + Align MemAlign = Arg.Flags[0].getNonZeroByValAlign(); + int FI = MFI.CreateStackObject(Arg.Flags[0].getByValSize(), MemAlign, + /*isSS=*/false); + + auto StackAddrSpace = DL.getAllocaAddrSpace(); + auto PtrLLT = getLLTForType(*PointerType::get(Ctx, StackAddrSpace), DL); + Register StackObjPtrVreg = + MF.getRegInfo().createGenericVirtualRegister(PtrLLT); + + MIRBuilder.buildFrameIndex(StackObjPtrVreg, FI); + + MachinePointerInfo DstPtrInfo = MachinePointerInfo::getFixedStack(MF, FI); + + MachinePointerInfo SrcPtrInfo(Arg.OrigValue); + if (!Arg.OrigValue) { + // We still need to accurately track the stack address space if we + // don't know the underlying value. 
+ SrcPtrInfo = MachinePointerInfo::getUnknownStack(MF); + } + + Align DstAlign = + std::max(MemAlign, inferAlignFromPtrInfo(MF, DstPtrInfo)); + + Align SrcAlign = + std::max(MemAlign, inferAlignFromPtrInfo(MF, SrcPtrInfo)); + + MachineMemOperand *SrcMMO = MF.getMachineMemOperand( + SrcPtrInfo, + MachineMemOperand::MOLoad | MachineMemOperand::MODereferenceable, + MemSize, SrcAlign); + + MachineMemOperand *DstMMO = MF.getMachineMemOperand( + DstPtrInfo, + MachineMemOperand::MOStore | MachineMemOperand::MODereferenceable, + MemSize, DstAlign); + + const LLT SizeTy = LLT::scalar(PtrLLT.getSizeInBits()); + + auto SizeConst = MIRBuilder.buildConstant(SizeTy, MemSize); + MIRBuilder.buildMemCpy(StackObjPtrVreg, Arg.Regs[0], SizeConst, *DstMMO, + *SrcMMO); + } + + splitToValueTypes(Arg, SplitArgs, DL, CallConv); + } + + unsigned NumFixedArgs = 0; + + for (auto &Arg : SplitArgs) { + EVT OrigVT = TLI.getValueType(DL, Arg.Ty); + MVT NewVT = TLI.getRegisterTypeForCallingConv(Ctx, CallConv, OrigVT); + LLT OrigLLT = getLLTForType(*Arg.Ty, DL); + LLT NewLLT = getLLTForMVT(NewVT); + + // If we need to split the type over multiple regs, check it's a scenario + // we currently support. + unsigned NumParts = + TLI.getNumRegistersForCallingConv(Ctx, CallConv, OrigVT); + + ISD::ArgFlagsTy OrigFlags = Arg.Flags[0]; + Arg.Flags.clear(); + + for (unsigned Part = 0; Part < NumParts; ++Part) { + ISD::ArgFlagsTy Flags = OrigFlags; + if (Part == 0) { + Flags.setSplit(); + } else { + Flags.setOrigAlign(Align(1)); + if (Part == NumParts - 1) + Flags.setSplitEnd(); + } + + Arg.Flags.push_back(Flags); + } + + Arg.OrigRegs.assign(Arg.Regs.begin(), Arg.Regs.end()); + if (NumParts != 1 || OrigVT != NewVT) { + // If we can't directly assign the register, we need one or more + // intermediate values. + Arg.Regs.resize(NumParts); + + // For each split register, create and assign a vreg that will store + // the incoming component of the larger value. 
These will later be + // merged to form the final vreg. + for (unsigned Part = 0; Part < NumParts; ++Part) { + Arg.Regs[Part] = MRI.createGenericVirtualRegister(NewLLT); + } + + buildCopyToRegs(MIRBuilder, Arg.Regs, Arg.OrigRegs[0], OrigLLT, NewLLT, + extendOpFromFlags(Arg.Flags[0])); + } + + if (!Arg.Flags[0].isVarArg()) { + for (unsigned Part = 0; Part < NumParts; ++Part) { + CallInst.addUse(Arg.Regs[Part]); + ++NumFixedArgs; + } + } + } + + if (CallConv == CallingConv::Swift) { + Type *PtrTy = PointerType::getUnqual(Ctx); + LLT PtrLLT = getLLTForType(*PtrTy, DL); + + if (!HasSwiftSelfArg) { + CallInst.addUse(MIRBuilder.buildUndef(PtrLLT).getReg(0)); + } + if (!HasSwiftErrorArg) { + CallInst.addUse(MIRBuilder.buildUndef(PtrLLT).getReg(0)); + } + } + + // Analyze operands of the call, assigning locations to each operand. + SmallVector ArgLocs; + CCState CCInfo(CallConv, Info.IsVarArg, MF, ArgLocs, Ctx); + + if (Info.IsVarArg) { + // Outgoing non-fixed arguments are placed in a buffer. First + // compute their offsets and the total amount of buffer space needed. 
+ for (ArgInfo &Arg : drop_begin(SplitArgs, NumFixedArgs)) { + EVT OrigVT = TLI.getValueType(DL, Arg.Ty); + MVT PartVT = TLI.getRegisterTypeForCallingConv(Ctx, CallConv, OrigVT); + Type *Ty = EVT(PartVT).getTypeForEVT(Ctx); + + for (unsigned Part = 0; Part < Arg.Regs.size(); ++Part) { + Align Alignment = std::max(Arg.Flags[Part].getNonZeroOrigAlign(), + DL.getABITypeAlign(Ty)); + unsigned Offset = + CCInfo.AllocateStack(DL.getTypeAllocSize(Ty), Alignment); + CCInfo.addLoc(CCValAssign::getMem(ArgLocs.size(), PartVT, Offset, + PartVT, CCValAssign::Full)); + } + } + } + + unsigned NumBytes = CCInfo.getAlignedCallFrameSize(); + + auto StackAddrSpace = DL.getAllocaAddrSpace(); + auto PtrLLT = getLLTForType(*PointerType::get(Ctx, StackAddrSpace), DL); + auto SizeLLT = LLT::scalar(PtrLLT.getSizeInBits()); + + if (Info.IsVarArg && NumBytes) { + Register VarArgStackPtr = + MF.getRegInfo().createGenericVirtualRegister(PtrLLT); + + MaybeAlign StackAlign = DL.getStackAlignment(); + assert(StackAlign && "data layout string is missing stack alignment"); + int FI = MF.getFrameInfo().CreateStackObject(NumBytes, *StackAlign, + /*isSS=*/false); + + MIRBuilder.buildFrameIndex(VarArgStackPtr, FI); + + unsigned ValNo = 0; + for (ArgInfo &Arg : drop_begin(SplitArgs, NumFixedArgs)) { + EVT OrigVT = TLI.getValueType(DL, Arg.Ty); + MVT PartVT = TLI.getRegisterTypeForCallingConv(Ctx, CallConv, OrigVT); + Type *Ty = EVT(PartVT).getTypeForEVT(Ctx); + + for (unsigned Part = 0; Part < Arg.Regs.size(); ++Part) { + Align Alignment = std::max(Arg.Flags[Part].getNonZeroOrigAlign(), + DL.getABITypeAlign(Ty)); + + unsigned Offset = ArgLocs[ValNo++].getLocMemOffset(); + + Register DstPtr = + MIRBuilder + .buildPtrAdd(PtrLLT, VarArgStackPtr, + MIRBuilder.buildConstant(SizeLLT, Offset).getReg(0)) + .getReg(0); + + MachineMemOperand *DstMMO = MF.getMachineMemOperand( + MachinePointerInfo::getFixedStack(MF, FI, Offset), + MachineMemOperand::MOStore | MachineMemOperand::MODereferenceable, + 
PartVT.getStoreSize(), Alignment); + + MIRBuilder.buildStore(Arg.Regs[Part], DstPtr, *DstMMO); + } + } + + CallInst.addUse(VarArgStackPtr); + } else if (Info.IsVarArg) { + CallInst.addUse(MIRBuilder.buildConstant(PtrLLT, 0).getReg(0)); + } + + if (IsIndirect) { + CallInst.addUse(IndirectIdx); + } + + MIRBuilder.insertInstr(CallInst); + + if (Info.LoweredTailCall) { + return true; + } + + if (Info.CanLowerReturn && !Info.OrigRet.Ty->isVoidTy()) { + SmallVector SplitEVTs; + ComputeValueVTs(TLI, DL, Info.OrigRet.Ty, SplitEVTs); + assert(Info.OrigRet.Regs.size() == SplitEVTs.size() && + "For each split Type there should be exactly one VReg."); + + SmallVector SplitReturns; + + unsigned i = 0; + for (auto SplitEVT : SplitEVTs) { + Register CurVReg = Info.OrigRet.Regs[i]; + ArgInfo CurArgInfo = ArgInfo{CurVReg, SplitEVT.getTypeForEVT(Ctx), 0}; + setArgFlags(CurArgInfo, AttributeList::ReturnIndex, DL, *Info.CB); + + splitToValueTypes(CurArgInfo, SplitReturns, DL, CallConv); + ++i; + } + + for (auto &Ret : SplitReturns) { + EVT OrigVT = TLI.getValueType(DL, Ret.Ty); + MVT NewVT = TLI.getRegisterTypeForCallingConv(Ctx, CallConv, OrigVT); + LLT OrigLLT = getLLTForType(*Ret.Ty, DL); + LLT NewLLT = getLLTForMVT(NewVT); + + // If we need to split the type over multiple regs, check it's a scenario + // we currently support. + unsigned NumParts = + TLI.getNumRegistersForCallingConv(Ctx, CallConv, OrigVT); + + ISD::ArgFlagsTy OrigFlags = Ret.Flags[0]; + Ret.Flags.clear(); + + for (unsigned Part = 0; Part < NumParts; ++Part) { + ISD::ArgFlagsTy Flags = OrigFlags; + if (Part == 0) { + Flags.setSplit(); + } else { + Flags.setOrigAlign(Align(1)); + if (Part == NumParts - 1) + Flags.setSplitEnd(); + } + + Ret.Flags.push_back(Flags); + } + + Ret.OrigRegs.assign(Ret.Regs.begin(), Ret.Regs.end()); + if (NumParts != 1 || OrigVT != NewVT) { + // If we can't directly assign the register, we need one or more + // intermediate values. 
+ Ret.Regs.resize(NumParts); + + // For each split register, create and assign a vreg that will store + // the incoming component of the larger value. These will later be + // merged to form the final vreg. + for (unsigned Part = 0; Part < NumParts; ++Part) { + Ret.Regs[Part] = MRI.createGenericVirtualRegister(NewLLT); + } + buildCopyFromRegs(MIRBuilder, Ret.OrigRegs, Ret.Regs, OrigLLT, NewLLT, + Ret.Flags[0]); + } + + for (unsigned Part = 0; Part < NumParts; ++Part) { + CallInst.addDef(Ret.Regs[Part]); + } + } + } + + if (!Info.CanLowerReturn) { + insertSRetLoads(MIRBuilder, Info.OrigRet.Ty, Info.OrigRet.Regs, + Info.DemoteRegister, Info.DemoteStackIndex); + + for (auto Reg : Info.OrigRet.Regs) { + CallInst.addDef(Reg); + } + } + + return true; } From 21d2a504bd6c87f12f5709b9c84d47d9d6094e6a Mon Sep 17 00:00:00 2001 From: Demetrius Kanios Date: Sun, 28 Sep 2025 22:41:11 -0700 Subject: [PATCH 03/17] Fix formatting --- .../GISel/WebAssemblyCallLowering.cpp | 128 ++++++++++-------- 1 file changed, 69 insertions(+), 59 deletions(-) diff --git a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.cpp b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.cpp index 8956932b403ef..23a6274e66661 100644 --- a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.cpp +++ b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.cpp @@ -697,7 +697,8 @@ bool WebAssemblyCallLowering::lowerCall(MachineIRBuilder &MIRBuilder, LLVMContext &Ctx = MIRBuilder.getContext(); const WebAssemblyTargetLowering &TLI = *getTLI(); MachineRegisterInfo &MRI = *MIRBuilder.getMRI(); - const WebAssemblySubtarget &Subtarget = MF.getSubtarget(); + const WebAssemblySubtarget &Subtarget = + MF.getSubtarget(); CallingConv::ID CallConv = Info.CallConv; if (!callingConvSupported(CallConv)) { @@ -716,7 +717,7 @@ bool WebAssemblyCallLowering::lowerCall(MachineIRBuilder &MIRBuilder, */ if (Info.IsTailCall) { - Info.LoweredTailCall = true; + Info.LoweredTailCall = true; auto NoTail = 
[&](const char *Msg) { if (Info.CB && Info.CB->isMustTailCall()) fail(MIRBuilder, Msg); @@ -773,65 +774,73 @@ bool WebAssemblyCallLowering::lowerCall(MachineIRBuilder &MIRBuilder, Register IndirectIdx; if (Info.Callee.isReg()) { - LLT CalleeType = MRI.getType(Info.Callee.getReg()); - assert(CalleeType.isPointer() && "Trying to lower a call with a Callee other than a pointer???"); - - IsIndirect = true; - CallInst = MIRBuilder.buildInstrNoInsert(Info.LoweredTailCall ? WebAssembly::RET_CALL_INDIRECT : WebAssembly::CALL_INDIRECT); - - // Placeholder for the type index. - // This gets replaced with the correct value in WebAssemblyMCInstLower.cpp - CallInst.addImm(0); - - MCSymbolWasm *Table; - if (CalleeType.getAddressSpace() == WebAssembly::WASM_ADDRESS_SPACE_DEFAULT) { - Table = WebAssembly::getOrCreateFunctionTableSymbol( - MF.getContext(), &Subtarget); - IndirectIdx = Info.Callee.getReg(); - - auto PtrSize = CalleeType.getSizeInBits(); - auto PtrIntLLT = LLT::scalar(PtrSize); - - IndirectIdx = MIRBuilder.buildPtrToInt(PtrIntLLT, IndirectIdx).getReg(0); - if (PtrSize > 32) { - IndirectIdx = MIRBuilder.buildTrunc(LLT::scalar(32), IndirectIdx).getReg(0); - } - } else if (CalleeType.getAddressSpace() == WebAssembly::WASM_ADDRESS_SPACE_FUNCREF) { - Table = WebAssembly::getOrCreateFuncrefCallTableSymbol( - MF.getContext(), &Subtarget); - - auto TableSetInstr = MIRBuilder.buildInstr(WebAssembly::TABLE_SET_FUNCREF); - TableSetInstr.addSym(Table); - TableSetInstr.addUse(Info.Callee.getReg()); - IndirectIdx = MIRBuilder.buildConstant(LLT::scalar(32), 0).getReg(0); - } else { - fail(MIRBuilder, "Invalid address space for indirect call"); - return false; + LLT CalleeType = MRI.getType(Info.Callee.getReg()); + assert(CalleeType.isPointer() && + "Trying to lower a call with a Callee other than a pointer???"); + + IsIndirect = true; + CallInst = MIRBuilder.buildInstrNoInsert( + Info.LoweredTailCall ? 
WebAssembly::RET_CALL_INDIRECT + : WebAssembly::CALL_INDIRECT); + + // Placeholder for the type index. + // This gets replaced with the correct value in WebAssemblyMCInstLower.cpp + CallInst.addImm(0); + + MCSymbolWasm *Table; + if (CalleeType.getAddressSpace() == + WebAssembly::WASM_ADDRESS_SPACE_DEFAULT) { + Table = WebAssembly::getOrCreateFunctionTableSymbol(MF.getContext(), + &Subtarget); + IndirectIdx = Info.Callee.getReg(); + + auto PtrSize = CalleeType.getSizeInBits(); + auto PtrIntLLT = LLT::scalar(PtrSize); + + IndirectIdx = MIRBuilder.buildPtrToInt(PtrIntLLT, IndirectIdx).getReg(0); + if (PtrSize > 32) { + IndirectIdx = + MIRBuilder.buildTrunc(LLT::scalar(32), IndirectIdx).getReg(0); } + } else if (CalleeType.getAddressSpace() == + WebAssembly::WASM_ADDRESS_SPACE_FUNCREF) { + Table = WebAssembly::getOrCreateFuncrefCallTableSymbol(MF.getContext(), + &Subtarget); + + auto TableSetInstr = + MIRBuilder.buildInstr(WebAssembly::TABLE_SET_FUNCREF); + TableSetInstr.addSym(Table); + TableSetInstr.addUse(Info.Callee.getReg()); + IndirectIdx = MIRBuilder.buildConstant(LLT::scalar(32), 0).getReg(0); + } else { + fail(MIRBuilder, "Invalid address space for indirect call"); + return false; + } - if (Subtarget.hasCallIndirectOverlong()) { - CallInst.addSym(Table); - } else { - // For the MVP there is at most one table whose number is 0, but we can't - // write a table symbol or issue relocations. Instead we just ensure the - // table is live and write a zero. - Table->setNoStrip(); - CallInst.addImm(0); - } + if (Subtarget.hasCallIndirectOverlong()) { + CallInst.addSym(Table); + } else { + // For the MVP there is at most one table whose number is 0, but we can't + // write a table symbol or issue relocations. Instead we just ensure the + // table is live and write a zero. + Table->setNoStrip(); + CallInst.addImm(0); + } } else { - CallInst = MIRBuilder.buildInstrNoInsert(Info.LoweredTailCall ? 
WebAssembly::RET_CALL : WebAssembly::CALL); - - if (Info.Callee.isGlobal()) { - CallInst.addGlobalAddress(Info.Callee.getGlobal()); - } else if (Info.Callee.isSymbol()) { - // TODO: figure out how to trigger/test this - CallInst.addSym(Info.Callee.getMCSymbol()); - } else { - llvm_unreachable("Trying to lower call with a callee other than reg, global, or a symbol."); - } + CallInst = MIRBuilder.buildInstrNoInsert( + Info.LoweredTailCall ? WebAssembly::RET_CALL : WebAssembly::CALL); + + if (Info.Callee.isGlobal()) { + CallInst.addGlobalAddress(Info.Callee.getGlobal()); + } else if (Info.Callee.isSymbol()) { + // TODO: figure out how to trigger/test this + CallInst.addSym(Info.Callee.getMCSymbol()); + } else { + llvm_unreachable("Trying to lower call with a callee other than reg, " + "global, or a symbol."); + } } - SmallVector SplitArgs; bool HasSwiftErrorArg = false; @@ -1028,8 +1037,9 @@ bool WebAssemblyCallLowering::lowerCall(MachineIRBuilder &MIRBuilder, Register DstPtr = MIRBuilder - .buildPtrAdd(PtrLLT, VarArgStackPtr, - MIRBuilder.buildConstant(SizeLLT, Offset).getReg(0)) + .buildPtrAdd( + PtrLLT, VarArgStackPtr, + MIRBuilder.buildConstant(SizeLLT, Offset).getReg(0)) .getReg(0); MachineMemOperand *DstMMO = MF.getMachineMemOperand( @@ -1053,7 +1063,7 @@ bool WebAssemblyCallLowering::lowerCall(MachineIRBuilder &MIRBuilder, MIRBuilder.insertInstr(CallInst); if (Info.LoweredTailCall) { - return true; + return true; } if (Info.CanLowerReturn && !Info.OrigRet.Ty->isVoidTy()) { From 748f10903d1a1d7da74ffd67edc62daabbc79621 Mon Sep 17 00:00:00 2001 From: Demetrius Kanios Date: Sun, 28 Sep 2025 22:41:18 -0700 Subject: [PATCH 04/17] Fix some issues with WebAssemblyCallLowering::lowerCall --- .../GISel/WebAssemblyCallLowering.cpp | 27 +++++++++++++------ 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.cpp b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.cpp index 
23a6274e66661..23533c5ad1c75 100644 --- a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.cpp +++ b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.cpp @@ -798,10 +798,6 @@ bool WebAssemblyCallLowering::lowerCall(MachineIRBuilder &MIRBuilder, auto PtrIntLLT = LLT::scalar(PtrSize); IndirectIdx = MIRBuilder.buildPtrToInt(PtrIntLLT, IndirectIdx).getReg(0); - if (PtrSize > 32) { - IndirectIdx = - MIRBuilder.buildTrunc(LLT::scalar(32), IndirectIdx).getReg(0); - } } else if (CalleeType.getAddressSpace() == WebAssembly::WASM_ADDRESS_SPACE_FUNCREF) { Table = WebAssembly::getOrCreateFuncrefCallTableSymbol(MF.getContext(), @@ -833,8 +829,7 @@ bool WebAssemblyCallLowering::lowerCall(MachineIRBuilder &MIRBuilder, if (Info.Callee.isGlobal()) { CallInst.addGlobalAddress(Info.Callee.getGlobal()); } else if (Info.Callee.isSymbol()) { - // TODO: figure out how to trigger/test this - CallInst.addSym(Info.Callee.getMCSymbol()); + CallInst.addExternalSymbol(Info.Callee.getSymbolName()); } else { llvm_unreachable("Trying to lower call with a callee other than reg, " "global, or a symbol."); @@ -1078,8 +1073,24 @@ bool WebAssemblyCallLowering::lowerCall(MachineIRBuilder &MIRBuilder, for (auto SplitEVT : SplitEVTs) { Register CurVReg = Info.OrigRet.Regs[i]; ArgInfo CurArgInfo = ArgInfo{CurVReg, SplitEVT.getTypeForEVT(Ctx), 0}; - setArgFlags(CurArgInfo, AttributeList::ReturnIndex, DL, *Info.CB); - + if (Info.CB) { + setArgFlags(CurArgInfo, AttributeList::ReturnIndex, DL, *Info.CB); + } else { + // we don't have a call base, so chances are we're looking at a libcall + // (external symbol). 
+ + // TODO: figure out how to get ALL the correct attributes + auto &Flags = CurArgInfo.Flags[0]; + PointerType *PtrTy = + dyn_cast(CurArgInfo.Ty->getScalarType()); + if (PtrTy) { + Flags.setPointer(); + Flags.setPointerAddrSpace(PtrTy->getPointerAddressSpace()); + } + Align MemAlign = DL.getABITypeAlign(CurArgInfo.Ty); + Flags.setMemAlign(MemAlign); + Flags.setOrigAlign(MemAlign); + } splitToValueTypes(CurArgInfo, SplitReturns, DL, CallConv); ++i; } From 8abb77c98fb5341c49fea150fdfdbf5294720c49 Mon Sep 17 00:00:00 2001 From: Demetrius Kanios Date: Sun, 28 Sep 2025 22:41:44 -0700 Subject: [PATCH 05/17] Attempt to make CallLowering floating-point aware (use FPEXT and FPTRUNC instead of integer ANYEXT/TRUNC) --- .../GISel/WebAssemblyCallLowering.cpp | 29 ++++++++++++++----- 1 file changed, 22 insertions(+), 7 deletions(-) diff --git a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.cpp b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.cpp index 23533c5ad1c75..3dd928a825995 100644 --- a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.cpp +++ b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.cpp @@ -29,6 +29,7 @@ #include "llvm/CodeGen/MachineFrameInfo.h" #include "llvm/CodeGen/MachineInstrBuilder.h" #include "llvm/CodeGen/MachineMemOperand.h" +#include "llvm/CodeGen/TargetOpcodes.h" #include "llvm/CodeGenTypes/LowLevelType.h" #include "llvm/IR/Argument.h" #include "llvm/IR/DataLayout.h" @@ -108,9 +109,12 @@ mergeVectorRegsToResultRegs(MachineIRBuilder &B, ArrayRef DstRegs, /// typed values to the original IR value. \p OrigRegs contains the destination /// value registers of type \p LLTy, and \p Regs contains the legalized pieces /// with type \p PartLLT. This is used for incoming values (physregs to vregs). 
+ +// Modified to account for floating-point extends/truncations static void buildCopyFromRegs(MachineIRBuilder &B, ArrayRef OrigRegs, ArrayRef Regs, LLT LLTy, LLT PartLLT, - const ISD::ArgFlagsTy Flags) { + const ISD::ArgFlagsTy Flags, + bool IsFloatingPoint) { MachineRegisterInfo &MRI = *B.getMRI(); if (PartLLT == LLTy) { @@ -153,7 +157,10 @@ static void buildCopyFromRegs(MachineIRBuilder &B, ArrayRef OrigRegs, return; } - B.buildTrunc(OrigRegs[0], SrcReg); + if (IsFloatingPoint) + B.buildFPTrunc(OrigRegs[0], SrcReg); + else + B.buildTrunc(OrigRegs[0], SrcReg); return; } @@ -166,7 +173,11 @@ static void buildCopyFromRegs(MachineIRBuilder &B, ArrayRef OrigRegs, B.buildMergeValues(OrigRegs[0], Regs); else { auto Widened = B.buildMergeLikeInstr(LLT::scalar(SrcSize), Regs); - B.buildTrunc(OrigRegs[0], Widened); + + if (IsFloatingPoint) + B.buildFPTrunc(OrigRegs[0], Widened); + else + B.buildTrunc(OrigRegs[0], Widened); } return; @@ -496,7 +507,9 @@ bool WebAssemblyCallLowering::lowerReturn(MachineIRBuilder &MIRBuilder, Arg.Regs[Part] = MRI.createGenericVirtualRegister(NewLLT); } buildCopyToRegs(MIRBuilder, Arg.Regs, Arg.OrigRegs[0], OrigLLT, NewLLT, - extendOpFromFlags(Arg.Flags[0])); + Arg.Ty->isFloatingPointTy() + ? TargetOpcode::G_FPEXT + : extendOpFromFlags(Arg.Flags[0])); } for (unsigned Part = 0; Part < NumParts; ++Part) { @@ -630,7 +643,7 @@ bool WebAssemblyCallLowering::lowerFormalArguments( Arg.Regs[Part] = MRI.createGenericVirtualRegister(NewLLT); } buildCopyFromRegs(MIRBuilder, Arg.OrigRegs, Arg.Regs, OrigLLT, NewLLT, - Arg.Flags[0]); + Arg.Flags[0], Arg.Ty->isFloatingPointTy()); } for (unsigned Part = 0; Part < NumParts; ++Part) { @@ -955,7 +968,9 @@ bool WebAssemblyCallLowering::lowerCall(MachineIRBuilder &MIRBuilder, } buildCopyToRegs(MIRBuilder, Arg.Regs, Arg.OrigRegs[0], OrigLLT, NewLLT, - extendOpFromFlags(Arg.Flags[0])); + Arg.Ty->isFloatingPointTy() + ? 
TargetOpcode::G_FPEXT + : extendOpFromFlags(Arg.Flags[0])); } if (!Arg.Flags[0].isVarArg()) { @@ -1135,7 +1150,7 @@ bool WebAssemblyCallLowering::lowerCall(MachineIRBuilder &MIRBuilder, Ret.Regs[Part] = MRI.createGenericVirtualRegister(NewLLT); } buildCopyFromRegs(MIRBuilder, Ret.OrigRegs, Ret.Regs, OrigLLT, NewLLT, - Ret.Flags[0]); + Ret.Flags[0], Ret.Ty->isFloatingPointTy()); } for (unsigned Part = 0; Part < NumParts; ++Part) { From 721a62827f5dbffacd38011b23b336feae01804c Mon Sep 17 00:00:00 2001 From: Demetrius Kanios Date: Sun, 28 Sep 2025 22:41:53 -0700 Subject: [PATCH 06/17] Fix lowerCall vararg crash. --- llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.cpp b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.cpp index 3dd928a825995..7ee118387bfe4 100644 --- a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.cpp +++ b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.cpp @@ -976,8 +976,8 @@ bool WebAssemblyCallLowering::lowerCall(MachineIRBuilder &MIRBuilder, if (!Arg.Flags[0].isVarArg()) { for (unsigned Part = 0; Part < NumParts; ++Part) { CallInst.addUse(Arg.Regs[Part]); - ++NumFixedArgs; } + ++NumFixedArgs; } } From f1b4a0507d33e376f3ce2cf798dc0d74ca241d60 Mon Sep 17 00:00:00 2001 From: Demetrius Kanios Date: Sun, 28 Sep 2025 22:42:03 -0700 Subject: [PATCH 07/17] Set up basic legalization (scalar only, limited support for FP, p0 only) --- .../GISel/WebAssemblyLegalizerInfo.cpp | 256 ++++++++++++++++++ .../GISel/WebAssemblyLegalizerInfo.h | 2 + 2 files changed, 258 insertions(+) diff --git a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyLegalizerInfo.cpp b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyLegalizerInfo.cpp index 3acdabb5612cc..c6cd1c5b371e9 100644 --- a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyLegalizerInfo.cpp +++ b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyLegalizerInfo.cpp @@ 
-11,6 +11,12 @@ //===----------------------------------------------------------------------===// #include "WebAssemblyLegalizerInfo.h" +#include "MCTargetDesc/WebAssemblyMCTargetDesc.h" +#include "WebAssemblySubtarget.h" +#include "llvm/CodeGen/GlobalISel/LegalizerHelper.h" +#include "llvm/CodeGen/MachineInstr.h" +#include "llvm/CodeGen/TargetOpcodes.h" +#include "llvm/IR/DerivedTypes.h" #define DEBUG_TYPE "wasm-legalinfo" @@ -19,5 +25,255 @@ using namespace LegalizeActions; WebAssemblyLegalizerInfo::WebAssemblyLegalizerInfo( const WebAssemblySubtarget &ST) { + using namespace TargetOpcode; + const LLT s8 = LLT::scalar(8); + const LLT s16 = LLT::scalar(16); + const LLT s32 = LLT::scalar(32); + const LLT s64 = LLT::scalar(64); + + const LLT p0 = LLT::pointer(0, ST.hasAddr64() ? 64 : 32); + const LLT p0s = LLT::scalar(ST.hasAddr64() ? 64 : 32); + + getActionDefinitionsBuilder(G_GLOBAL_VALUE).legalFor({p0}); + + getActionDefinitionsBuilder(G_PHI) + .legalFor({p0, s32, s64}) + .widenScalarToNextPow2(0) + .clampScalar(0, s32, s64); + getActionDefinitionsBuilder(G_BR).alwaysLegal(); + getActionDefinitionsBuilder(G_BRCOND).legalFor({s32}).clampScalar(0, s32, + s32); + getActionDefinitionsBuilder(G_BRJT) + .legalFor({{p0, s32}}) + .clampScalar(1, s32, s32); + + getActionDefinitionsBuilder(G_SELECT) + .legalFor({{s32, s32}, {s64, s32}, {p0, s32}}) + .widenScalarToNextPow2(0) + .clampScalar(0, s32, s64) + .clampScalar(1, s32, s32); + + getActionDefinitionsBuilder(G_JUMP_TABLE).legalFor({p0}); + + getActionDefinitionsBuilder(G_ICMP) + .legalFor({{s32, s32}, {s32, s64}, {s32, p0}}) + .widenScalarToNextPow2(1) + .clampScalar(1, s32, s64) + .clampScalar(0, s32, s32); + + getActionDefinitionsBuilder(G_FCMP) + .legalFor({{s32, s32}, {s32, s64}}) + .clampScalar(0, s32, s32) + .libcall(); + + getActionDefinitionsBuilder(G_FRAME_INDEX).legalFor({p0}); + + getActionDefinitionsBuilder(G_CONSTANT) + .legalFor({s32, s64, p0}) + .widenScalarToNextPow2(0) + .clampScalar(0, s32, s64); + + 
getActionDefinitionsBuilder(G_FCONSTANT) + .legalFor({s32, s64}) + .clampScalar(0, s32, s64); + + getActionDefinitionsBuilder(G_IMPLICIT_DEF) + .legalFor({s32, s64, p0}) + .widenScalarToNextPow2(0) + .clampScalar(0, s32, s64); + + getActionDefinitionsBuilder( + {G_ADD, G_SUB, G_MUL, G_UDIV, G_SDIV, G_UREM, G_SREM}) + .legalFor({s32, s64}) + .widenScalarToNextPow2(0) + .clampScalar(0, s32, s64); + + getActionDefinitionsBuilder({G_ASHR, G_LSHR, G_SHL, G_CTLZ, G_CTLZ_ZERO_UNDEF, + G_CTTZ, G_CTTZ_ZERO_UNDEF, G_CTPOP, G_FSHL, + G_FSHR}) + .legalFor({{s32, s32}, {s64, s64}}) + .widenScalarToNextPow2(0) + .clampScalar(0, s32, s64) + .minScalarSameAs(1, 0) + .maxScalarSameAs(1, 0); + + getActionDefinitionsBuilder({G_SCMP, G_UCMP}).lower(); + + getActionDefinitionsBuilder({G_AND, G_OR, G_XOR}) + .legalFor({s32, s64}) + .widenScalarToNextPow2(0) + .clampScalar(0, s32, s64); + + getActionDefinitionsBuilder({G_UMIN, G_UMAX, G_SMIN, G_SMAX}).lower(); + + getActionDefinitionsBuilder({G_FADD, G_FSUB, G_FDIV, G_FMUL, G_FNEG, G_FABS, + G_FCEIL, G_FFLOOR, G_FSQRT, G_INTRINSIC_TRUNC, + G_FNEARBYINT, G_FRINT, G_INTRINSIC_ROUNDEVEN, + G_FMINIMUM, G_FMAXIMUM}) + .legalFor({s32, s64}) + .minScalar(0, s32); + + // TODO: _IEEE not lowering correctly? 
+ getActionDefinitionsBuilder( + {G_FMINNUM, G_FMAXNUM, G_FMINNUM_IEEE, G_FMAXNUM_IEEE}) + .lowerFor({s32, s64}) + .minScalar(0, s32); + + getActionDefinitionsBuilder({G_FMA, G_FREM}) + .libcallFor({s32, s64}) + .minScalar(0, s32); + + getActionDefinitionsBuilder(G_FCOPYSIGN) + .legalFor({s32, s64}) + .minScalar(0, s32) + .minScalarSameAs(1, 0) + .maxScalarSameAs(1, 0); + + getActionDefinitionsBuilder({G_FPTOUI, G_FPTOUI_SAT, G_FPTOSI, G_FPTOSI_SAT}) + .legalForCartesianProduct({s32, s64}, {s32, s64}) + .minScalar(1, s32) + .widenScalarToNextPow2(0) + .clampScalar(0, s32, s64); + + getActionDefinitionsBuilder({G_UITOFP, G_SITOFP}) + .legalForCartesianProduct({s32, s64}, {s32, s64}) + .minScalar(1, s32) + .widenScalarToNextPow2(1) + .clampScalar(1, s32, s64); + + getActionDefinitionsBuilder(G_PTRTOINT).legalFor({{p0s, p0}}); + getActionDefinitionsBuilder(G_INTTOPTR).legalFor({{p0, p0s}}); + getActionDefinitionsBuilder(G_PTR_ADD).legalFor({{p0, p0s}}); + + getActionDefinitionsBuilder(G_LOAD) + .legalForTypesWithMemDesc( + {{s32, p0, s32, 1}, {s64, p0, s64, 1}, {p0, p0, p0, 1}}) + .legalForTypesWithMemDesc({{s32, p0, s8, 1}, + {s32, p0, s16, 1}, + + {s64, p0, s8, 1}, + {s64, p0, s16, 1}, + {s64, p0, s32, 1}}) + .widenScalarToNextPow2(0) + .lowerIfMemSizeNotByteSizePow2() + .clampScalar(0, s32, s64); + + getActionDefinitionsBuilder(G_STORE) + .legalForTypesWithMemDesc( + {{s32, p0, s32, 1}, {s64, p0, s64, 1}, {p0, p0, p0, 1}}) + .legalForTypesWithMemDesc({{s32, p0, s8, 1}, + {s32, p0, s16, 1}, + + {s64, p0, s8, 1}, + {s64, p0, s16, 1}, + {s64, p0, s32, 1}}) + .widenScalarToNextPow2(0) + .lowerIfMemSizeNotByteSizePow2() + .clampScalar(0, s32, s64); + + getActionDefinitionsBuilder({G_ZEXTLOAD, G_SEXTLOAD}) + .legalForTypesWithMemDesc({{s32, p0, s8, 1}, + {s32, p0, s16, 1}, + + {s64, p0, s8, 1}, + {s64, p0, s16, 1}, + {s64, p0, s32, 1}}) + .widenScalarToNextPow2(0) + .lowerIfMemSizeNotByteSizePow2() + .clampScalar(0, s32, s64) + .lower(); + + if (ST.hasBulkMemoryOpt()) { 
+ getActionDefinitionsBuilder(G_BZERO).unsupported(); + + getActionDefinitionsBuilder(G_MEMSET) + .legalForCartesianProduct({p0}, {s32}, {p0s}) + .customForCartesianProduct({p0}, {s8}, {p0s}) + .immIdx(0); + + getActionDefinitionsBuilder({G_MEMCPY, G_MEMMOVE}) + .legalForCartesianProduct({p0}, {p0}, {p0s}) + .immIdx(0); + + getActionDefinitionsBuilder(G_MEMCPY_INLINE) + .legalForCartesianProduct({p0}, {p0}, {p0s}); + } else { + getActionDefinitionsBuilder({G_BZERO, G_MEMCPY, G_MEMMOVE, G_MEMSET}) + .libcall(); + } + + // TODO: figure out how to combine G_ANYEXT of G_ASSERT_{S|Z}EXT (or + // appropriate G_AND and G_SEXT_IN_REG?) to a G_{S|Z}EXT + G_ASSERT_{S|Z}EXT + // for better optimization (since G_ANYEXT lowers to a ZEXT or SEXT + // instruction anyway). + + getActionDefinitionsBuilder(G_ANYEXT) + .legalFor({{s64, s32}}) + .clampScalar(0, s32, s64) + .clampScalar(1, s32, s64); + + getActionDefinitionsBuilder({G_SEXT, G_ZEXT}) + .legalFor({{s64, s32}}) + .clampScalar(0, s32, s64) + .clampScalar(1, s32, s64) + .lower(); + + if (ST.hasSignExt()) { + getActionDefinitionsBuilder(G_SEXT_INREG) + .clampScalar(0, s32, s64) + .customFor({s32, s64}) + .lower(); + } else { + getActionDefinitionsBuilder(G_SEXT_INREG).lower(); + } + + getActionDefinitionsBuilder(G_TRUNC) + .legalFor({{s32, s64}}) + .clampScalar(0, s32, s64) + .clampScalar(1, s32, s64) + .lower(); + + getActionDefinitionsBuilder(G_FPEXT).legalFor({{s64, s32}}); + + getActionDefinitionsBuilder(G_FPTRUNC).legalFor({{s32, s64}}); + + getActionDefinitionsBuilder(G_VASTART).legalFor({p0}); + getActionDefinitionsBuilder(G_VAARG) + .legalForCartesianProduct({s32, s64}, {p0}) + .clampScalar(0, s32, s64); + getLegacyLegalizerInfo().computeTables(); } + +bool WebAssemblyLegalizerInfo::legalizeCustom( + LegalizerHelper &Helper, MachineInstr &MI, + LostDebugLocObserver &LocObserver) const { + switch (MI.getOpcode()) { + case TargetOpcode::G_SEXT_INREG: { + // Mark only 8/16/32-bit SEXT_INREG as legal + auto [DstType, 
SrcType] = MI.getFirst2LLTs(); + auto ExtFromWidth = MI.getOperand(2).getImm(); + + if (ExtFromWidth == 8 || ExtFromWidth == 16 || + (DstType.getScalarSizeInBits() == 64 && ExtFromWidth == 32)) { + return true; + } + return false; + } + case TargetOpcode::G_MEMSET: { + // Anyext the value being set to 32 bit (only the bottom 8 bits are read by + // the instruction). + Helper.Observer.changingInstr(MI); + auto &Value = MI.getOperand(1); + + Register ExtValueReg = + Helper.MIRBuilder.buildAnyExt(LLT::scalar(32), Value).getReg(0); + Value.setReg(ExtValueReg); + Helper.Observer.changedInstr(MI); + return true; + } + default: + break; + } + return false; +} diff --git a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyLegalizerInfo.h b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyLegalizerInfo.h index c02205fc7ae0d..5aca23c9514e1 100644 --- a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyLegalizerInfo.h +++ b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyLegalizerInfo.h @@ -24,6 +24,8 @@ class WebAssemblySubtarget; class WebAssemblyLegalizerInfo : public LegalizerInfo { public: WebAssemblyLegalizerInfo(const WebAssemblySubtarget &ST); + + bool legalizeCustom(LegalizerHelper &Helper, MachineInstr &MI, LostDebugLocObserver &LocObserver) const override; }; } // namespace llvm #endif From fa779c4b9a1a5cff2a731aefe7531983c7bd0092 Mon Sep 17 00:00:00 2001 From: Demetrius Kanios Date: Sun, 28 Sep 2025 22:42:32 -0700 Subject: [PATCH 08/17] start on regbankselect --- llvm/lib/Target/WebAssembly/CMakeLists.txt | 1 + .../GISel/WebAssemblyCallLowering.cpp | 299 ++++++++++------- .../GISel/WebAssemblyRegisterBankInfo.cpp | 302 ++++++++++++++++++ .../GISel/WebAssemblyRegisterBankInfo.h | 40 +++ llvm/lib/Target/WebAssembly/WebAssembly.td | 1 + .../WebAssembly/WebAssemblyRegisterBanks.td | 20 ++ .../WebAssembly/WebAssemblySubtarget.cpp | 4 +- 7 files changed, 549 insertions(+), 118 deletions(-) create mode 100644 llvm/lib/Target/WebAssembly/WebAssemblyRegisterBanks.td diff --git 
a/llvm/lib/Target/WebAssembly/CMakeLists.txt b/llvm/lib/Target/WebAssembly/CMakeLists.txt index ffb4ad182c81b..e0be3d6af2f0c 100644 --- a/llvm/lib/Target/WebAssembly/CMakeLists.txt +++ b/llvm/lib/Target/WebAssembly/CMakeLists.txt @@ -9,6 +9,7 @@ tablegen(LLVM WebAssemblyGenDisassemblerTables.inc -gen-disassembler) tablegen(LLVM WebAssemblyGenFastISel.inc -gen-fast-isel) tablegen(LLVM WebAssemblyGenInstrInfo.inc -gen-instr-info) tablegen(LLVM WebAssemblyGenMCCodeEmitter.inc -gen-emitter) +tablegen(LLVM WebAssemblyGenRegisterBank.inc -gen-register-bank) tablegen(LLVM WebAssemblyGenRegisterInfo.inc -gen-register-info) tablegen(LLVM WebAssemblyGenSDNodeInfo.inc -gen-sd-node-info) tablegen(LLVM WebAssemblyGenSubtargetInfo.inc -gen-subtarget) diff --git a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.cpp b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.cpp index 7ee118387bfe4..7f3c7e62ce02d 100644 --- a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.cpp +++ b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.cpp @@ -13,10 +13,12 @@ //===----------------------------------------------------------------------===// #include "WebAssemblyCallLowering.h" +#include "GISel/WebAssemblyRegisterBankInfo.h" #include "MCTargetDesc/WebAssemblyMCTargetDesc.h" #include "Utils/WasmAddressSpaces.h" #include "WebAssemblyISelLowering.h" #include "WebAssemblyMachineFunctionInfo.h" +#include "WebAssemblyRegisterInfo.h" #include "WebAssemblySubtarget.h" #include "WebAssemblyUtilities.h" #include "llvm/Analysis/MemoryLocation.h" @@ -25,6 +27,7 @@ #include "llvm/CodeGen/FunctionLoweringInfo.h" #include "llvm/CodeGen/GlobalISel/CallLowering.h" #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h" +#include "llvm/CodeGen/GlobalISel/Utils.h" #include "llvm/CodeGen/LowLevelTypeUtils.h" #include "llvm/CodeGen/MachineFrameInfo.h" #include "llvm/CodeGen/MachineInstrBuilder.h" @@ -435,6 +438,12 @@ bool WebAssemblyCallLowering::lowerReturn(MachineIRBuilder 
&MIRBuilder, FunctionLoweringInfo &FLI, Register SwiftErrorVReg) const { auto MIB = MIRBuilder.buildInstrNoInsert(WebAssembly::RETURN); + MachineFunction &MF = MIRBuilder.getMF(); + auto &TLI = *getTLI(); + auto &Subtarget = MF.getSubtarget(); + auto &TRI = *Subtarget.getRegisterInfo(); + auto &TII = *Subtarget.getInstrInfo(); + auto &RBI = *Subtarget.getRegBankInfo(); assert(((Val && !VRegs.empty()) || (!Val && VRegs.empty())) && "Return value without a vreg"); @@ -513,7 +522,11 @@ bool WebAssemblyCallLowering::lowerReturn(MachineIRBuilder &MIRBuilder, } for (unsigned Part = 0; Part < NumParts; ++Part) { - MIB.addUse(Arg.Regs[Part]); + auto NewOutReg = constrainRegToClass(MRI, TII, RBI, Arg.Regs[Part], + *TLI.getRegClassFor(NewVT)); + if (NewOutReg != Arg.Regs[Part]) + MIRBuilder.buildCopy(NewOutReg, Arg.Regs[Part]); + MIB.addUse(NewOutReg); } } } @@ -564,6 +577,11 @@ bool WebAssemblyCallLowering::lowerFormalArguments( WebAssemblyFunctionInfo *MFI = MF.getInfo(); const DataLayout &DL = F.getDataLayout(); auto &TLI = *getTLI(); + auto &Subtarget = MF.getSubtarget(); + auto &TRI = *Subtarget.getRegisterInfo(); + auto &TII = *Subtarget.getInstrInfo(); + auto &RBI = *Subtarget.getRegBankInfo(); + LLVMContext &Ctx = MIRBuilder.getContext(); const CallingConv::ID CallConv = F.getCallingConv(); @@ -647,9 +665,12 @@ bool WebAssemblyCallLowering::lowerFormalArguments( } for (unsigned Part = 0; Part < NumParts; ++Part) { - MIRBuilder.buildInstr(getWASMArgOpcode(NewVT)) - .addDef(Arg.Regs[Part]) - .addImm(FinalArgIdx); + auto ArgInst = MIRBuilder.buildInstr(getWASMArgOpcode(NewVT)) + .addDef(Arg.Regs[Part]) + .addImm(FinalArgIdx); + + constrainOperandRegClass(MF, TRI, MRI, TII, RBI, *ArgInst, + ArgInst->getDesc(), ArgInst->getOperand(0), 0); MFI->addParam(NewVT); ++FinalArgIdx; } @@ -712,6 +733,9 @@ bool WebAssemblyCallLowering::lowerCall(MachineIRBuilder &MIRBuilder, MachineRegisterInfo &MRI = *MIRBuilder.getMRI(); const WebAssemblySubtarget &Subtarget = MF.getSubtarget(); 
+ auto &TRI = *Subtarget.getRegisterInfo(); + auto &TII = *Subtarget.getInstrInfo(); + auto &RBI = *Subtarget.getRegBankInfo(); CallingConv::ID CallConv = Info.CallConv; if (!callingConvSupported(CallConv)) { @@ -781,21 +805,128 @@ bool WebAssemblyCallLowering::lowerCall(MachineIRBuilder &MIRBuilder, } } + if (Info.LoweredTailCall) { + MF.getFrameInfo().setHasTailCall(); + } + MachineInstrBuilder CallInst; bool IsIndirect = false; Register IndirectIdx; + if (Info.Callee.isReg()) { + IsIndirect = true; + CallInst = MIRBuilder.buildInstr(Info.LoweredTailCall + ? WebAssembly::RET_CALL_INDIRECT + : WebAssembly::CALL_INDIRECT); + } else { + CallInst = MIRBuilder.buildInstr( + Info.LoweredTailCall ? WebAssembly::RET_CALL : WebAssembly::CALL); + } + + if (!Info.LoweredTailCall) { + if (Info.CanLowerReturn && !Info.OrigRet.Ty->isVoidTy()) { + SmallVector SplitEVTs; + ComputeValueVTs(TLI, DL, Info.OrigRet.Ty, SplitEVTs); + assert(Info.OrigRet.Regs.size() == SplitEVTs.size() && + "For each split Type there should be exactly one VReg."); + + SmallVector SplitReturns; + + unsigned i = 0; + for (auto SplitEVT : SplitEVTs) { + Register CurVReg = Info.OrigRet.Regs[i]; + ArgInfo CurArgInfo = ArgInfo{CurVReg, SplitEVT.getTypeForEVT(Ctx), 0}; + if (Info.CB) { + setArgFlags(CurArgInfo, AttributeList::ReturnIndex, DL, *Info.CB); + } else { + // we don't have a call base, so chances are we're looking at a + // libcall (external symbol). 
+ + // TODO: figure out how to get ALL the correct attributes + auto &Flags = CurArgInfo.Flags[0]; + PointerType *PtrTy = + dyn_cast(CurArgInfo.Ty->getScalarType()); + if (PtrTy) { + Flags.setPointer(); + Flags.setPointerAddrSpace(PtrTy->getPointerAddressSpace()); + } + Align MemAlign = DL.getABITypeAlign(CurArgInfo.Ty); + Flags.setMemAlign(MemAlign); + Flags.setOrigAlign(MemAlign); + } + splitToValueTypes(CurArgInfo, SplitReturns, DL, CallConv); + ++i; + } + + for (auto &Ret : SplitReturns) { + EVT OrigVT = TLI.getValueType(DL, Ret.Ty); + MVT NewVT = TLI.getRegisterTypeForCallingConv(Ctx, CallConv, OrigVT); + LLT OrigLLT = getLLTForType(*Ret.Ty, DL); + LLT NewLLT = getLLTForMVT(NewVT); + + // If we need to split the type over multiple regs, check it's a + // scenario we currently support. + unsigned NumParts = + TLI.getNumRegistersForCallingConv(Ctx, CallConv, OrigVT); + + ISD::ArgFlagsTy OrigFlags = Ret.Flags[0]; + Ret.Flags.clear(); + + for (unsigned Part = 0; Part < NumParts; ++Part) { + ISD::ArgFlagsTy Flags = OrigFlags; + if (Part == 0) { + Flags.setSplit(); + } else { + Flags.setOrigAlign(Align(1)); + if (Part == NumParts - 1) + Flags.setSplitEnd(); + } + + Ret.Flags.push_back(Flags); + } + + Ret.OrigRegs.assign(Ret.Regs.begin(), Ret.Regs.end()); + if (NumParts != 1 || OrigVT != NewVT) { + // If we can't directly assign the register, we need one or more + // intermediate values. + Ret.Regs.resize(NumParts); + + // For each split register, create and assign a vreg that will store + // the incoming component of the larger value. These will later be + // merged to form the final vreg. 
+ for (unsigned Part = 0; Part < NumParts; ++Part) { + Ret.Regs[Part] = MRI.createGenericVirtualRegister(NewLLT); + } + buildCopyFromRegs(MIRBuilder, Ret.OrigRegs, Ret.Regs, OrigLLT, NewLLT, + Ret.Flags[0], Ret.Ty->isFloatingPointTy()); + } + + for (unsigned Part = 0; Part < NumParts; ++Part) { + // MRI.setRegClass(Ret.Regs[Part], TLI.getRegClassFor(NewVT)); + auto NewRetReg = constrainRegToClass(MRI, TII, RBI, Ret.Regs[Part], + *TLI.getRegClassFor(NewVT)); + if (Ret.Regs[Part] != NewRetReg) + MIRBuilder.buildCopy(NewRetReg, Ret.Regs[Part]); + + CallInst.addDef(Ret.Regs[Part]); + } + } + } + + if (!Info.CanLowerReturn) { + insertSRetLoads(MIRBuilder, Info.OrigRet.Ty, Info.OrigRet.Regs, + Info.DemoteRegister, Info.DemoteStackIndex); + } + } + auto SavedInsertPt = MIRBuilder.getInsertPt(); + MIRBuilder.setInstr(*CallInst); + if (Info.Callee.isReg()) { LLT CalleeType = MRI.getType(Info.Callee.getReg()); assert(CalleeType.isPointer() && "Trying to lower a call with a Callee other than a pointer???"); - IsIndirect = true; - CallInst = MIRBuilder.buildInstrNoInsert( - Info.LoweredTailCall ? WebAssembly::RET_CALL_INDIRECT - : WebAssembly::CALL_INDIRECT); - // Placeholder for the type index. 
// This gets replaced with the correct value in WebAssemblyMCInstLower.cpp CallInst.addImm(0); @@ -816,11 +947,25 @@ bool WebAssemblyCallLowering::lowerCall(MachineIRBuilder &MIRBuilder, Table = WebAssembly::getOrCreateFuncrefCallTableSymbol(MF.getContext(), &Subtarget); + Type *PtrTy = PointerType::getUnqual(Ctx); + LLT PtrLLT = getLLTForType(*PtrTy, DL); + auto PtrIntLLT = LLT::scalar(PtrLLT.getSizeInBits()); + + IndirectIdx = MIRBuilder.buildConstant(PtrIntLLT, 0).getReg(0); + auto TableSetInstr = MIRBuilder.buildInstr(WebAssembly::TABLE_SET_FUNCREF); TableSetInstr.addSym(Table); + TableSetInstr.addUse(IndirectIdx); TableSetInstr.addUse(Info.Callee.getReg()); - IndirectIdx = MIRBuilder.buildConstant(LLT::scalar(32), 0).getReg(0); + + constrainOperandRegClass(MF, TRI, MRI, TII, RBI, *TableSetInstr, + TableSetInstr->getDesc(), + TableSetInstr->getOperand(1), 1); + constrainOperandRegClass(MF, TRI, MRI, TII, RBI, *TableSetInstr, + TableSetInstr->getDesc(), + TableSetInstr->getOperand(2), 2); + } else { fail(MIRBuilder, "Invalid address space for indirect call"); return false; @@ -836,9 +981,6 @@ bool WebAssemblyCallLowering::lowerCall(MachineIRBuilder &MIRBuilder, CallInst.addImm(0); } } else { - CallInst = MIRBuilder.buildInstrNoInsert( - Info.LoweredTailCall ? 
WebAssembly::RET_CALL : WebAssembly::CALL); - if (Info.Callee.isGlobal()) { CallInst.addGlobalAddress(Info.Callee.getGlobal()); } else if (Info.Callee.isSymbol()) { @@ -884,9 +1026,13 @@ bool WebAssemblyCallLowering::lowerCall(MachineIRBuilder &MIRBuilder, /*isSS=*/false); auto StackAddrSpace = DL.getAllocaAddrSpace(); - auto PtrLLT = getLLTForType(*PointerType::get(Ctx, StackAddrSpace), DL); + auto PtrLLT = + LLT::pointer(StackAddrSpace, DL.getPointerSizeInBits(StackAddrSpace)); + Register StackObjPtrVreg = MF.getRegInfo().createGenericVirtualRegister(PtrLLT); + MRI.setRegClass(StackObjPtrVreg, TLI.getRepRegClassFor(TLI.getPointerTy( + DL, StackAddrSpace))); MIRBuilder.buildFrameIndex(StackObjPtrVreg, FI); @@ -975,6 +1121,10 @@ bool WebAssemblyCallLowering::lowerCall(MachineIRBuilder &MIRBuilder, if (!Arg.Flags[0].isVarArg()) { for (unsigned Part = 0; Part < NumParts; ++Part) { + auto NewArgReg = constrainRegToClass(MRI, TII, RBI, Arg.Regs[Part], + *TLI.getRegClassFor(NewVT)); + if (Arg.Regs[Part] != NewArgReg) + MIRBuilder.buildCopy(NewArgReg, Arg.Regs[Part]); CallInst.addUse(Arg.Regs[Part]); } ++NumFixedArgs; @@ -984,12 +1134,17 @@ bool WebAssemblyCallLowering::lowerCall(MachineIRBuilder &MIRBuilder, if (CallConv == CallingConv::Swift) { Type *PtrTy = PointerType::getUnqual(Ctx); LLT PtrLLT = getLLTForType(*PtrTy, DL); + auto &PtrRegClass = *TLI.getRegClassFor(TLI.getSimpleValueType(DL, PtrTy)); if (!HasSwiftSelfArg) { - CallInst.addUse(MIRBuilder.buildUndef(PtrLLT).getReg(0)); + auto NewUndefReg = MIRBuilder.buildUndef(PtrLLT).getReg(0); + MRI.setRegClass(NewUndefReg, &PtrRegClass); + CallInst.addUse(NewUndefReg); } if (!HasSwiftErrorArg) { - CallInst.addUse(MIRBuilder.buildUndef(PtrLLT).getReg(0)); + auto NewUndefReg = MIRBuilder.buildUndef(PtrLLT).getReg(0); + MRI.setRegClass(NewUndefReg, &PtrRegClass); + CallInst.addUse(NewUndefReg); } } @@ -1019,12 +1174,14 @@ bool WebAssemblyCallLowering::lowerCall(MachineIRBuilder &MIRBuilder, unsigned NumBytes = 
CCInfo.getAlignedCallFrameSize(); auto StackAddrSpace = DL.getAllocaAddrSpace(); - auto PtrLLT = getLLTForType(*PointerType::get(Ctx, StackAddrSpace), DL); - auto SizeLLT = LLT::scalar(PtrLLT.getSizeInBits()); + auto PtrLLT = LLT::pointer(StackAddrSpace, DL.getPointerSizeInBits(0)); + auto SizeLLT = LLT::scalar(DL.getPointerSizeInBits(StackAddrSpace)); + auto *PtrRegClass = TLI.getRegClassFor(TLI.getPointerTy(DL, StackAddrSpace)); if (Info.IsVarArg && NumBytes) { Register VarArgStackPtr = MF.getRegInfo().createGenericVirtualRegister(PtrLLT); + MRI.setRegClass(VarArgStackPtr, PtrRegClass); MaybeAlign StackAlign = DL.getStackAlignment(); assert(StackAlign && "data layout string is missing stack alignment"); @@ -1063,110 +1220,20 @@ bool WebAssemblyCallLowering::lowerCall(MachineIRBuilder &MIRBuilder, CallInst.addUse(VarArgStackPtr); } else if (Info.IsVarArg) { - CallInst.addUse(MIRBuilder.buildConstant(PtrLLT, 0).getReg(0)); + auto NewArgReg = MIRBuilder.buildConstant(PtrLLT, 0).getReg(0); + MRI.setRegClass(NewArgReg, PtrRegClass); + CallInst.addUse(NewArgReg); } if (IsIndirect) { + auto NewArgReg = + constrainRegToClass(MRI, TII, RBI, IndirectIdx, *PtrRegClass); + if (IndirectIdx != NewArgReg) + MIRBuilder.buildCopy(NewArgReg, IndirectIdx); CallInst.addUse(IndirectIdx); } - MIRBuilder.insertInstr(CallInst); - - if (Info.LoweredTailCall) { - return true; - } - - if (Info.CanLowerReturn && !Info.OrigRet.Ty->isVoidTy()) { - SmallVector SplitEVTs; - ComputeValueVTs(TLI, DL, Info.OrigRet.Ty, SplitEVTs); - assert(Info.OrigRet.Regs.size() == SplitEVTs.size() && - "For each split Type there should be exactly one VReg."); - - SmallVector SplitReturns; - - unsigned i = 0; - for (auto SplitEVT : SplitEVTs) { - Register CurVReg = Info.OrigRet.Regs[i]; - ArgInfo CurArgInfo = ArgInfo{CurVReg, SplitEVT.getTypeForEVT(Ctx), 0}; - if (Info.CB) { - setArgFlags(CurArgInfo, AttributeList::ReturnIndex, DL, *Info.CB); - } else { - // we don't have a call base, so chances are we're looking 
at a libcall - // (external symbol). - - // TODO: figure out how to get ALL the correct attributes - auto &Flags = CurArgInfo.Flags[0]; - PointerType *PtrTy = - dyn_cast(CurArgInfo.Ty->getScalarType()); - if (PtrTy) { - Flags.setPointer(); - Flags.setPointerAddrSpace(PtrTy->getPointerAddressSpace()); - } - Align MemAlign = DL.getABITypeAlign(CurArgInfo.Ty); - Flags.setMemAlign(MemAlign); - Flags.setOrigAlign(MemAlign); - } - splitToValueTypes(CurArgInfo, SplitReturns, DL, CallConv); - ++i; - } - - for (auto &Ret : SplitReturns) { - EVT OrigVT = TLI.getValueType(DL, Ret.Ty); - MVT NewVT = TLI.getRegisterTypeForCallingConv(Ctx, CallConv, OrigVT); - LLT OrigLLT = getLLTForType(*Ret.Ty, DL); - LLT NewLLT = getLLTForMVT(NewVT); - - // If we need to split the type over multiple regs, check it's a scenario - // we currently support. - unsigned NumParts = - TLI.getNumRegistersForCallingConv(Ctx, CallConv, OrigVT); - - ISD::ArgFlagsTy OrigFlags = Ret.Flags[0]; - Ret.Flags.clear(); - - for (unsigned Part = 0; Part < NumParts; ++Part) { - ISD::ArgFlagsTy Flags = OrigFlags; - if (Part == 0) { - Flags.setSplit(); - } else { - Flags.setOrigAlign(Align(1)); - if (Part == NumParts - 1) - Flags.setSplitEnd(); - } - - Ret.Flags.push_back(Flags); - } - - Ret.OrigRegs.assign(Ret.Regs.begin(), Ret.Regs.end()); - if (NumParts != 1 || OrigVT != NewVT) { - // If we can't directly assign the register, we need one or more - // intermediate values. - Ret.Regs.resize(NumParts); - - // For each split register, create and assign a vreg that will store - // the incoming component of the larger value. These will later be - // merged to form the final vreg. 
- for (unsigned Part = 0; Part < NumParts; ++Part) { - Ret.Regs[Part] = MRI.createGenericVirtualRegister(NewLLT); - } - buildCopyFromRegs(MIRBuilder, Ret.OrigRegs, Ret.Regs, OrigLLT, NewLLT, - Ret.Flags[0], Ret.Ty->isFloatingPointTy()); - } - - for (unsigned Part = 0; Part < NumParts; ++Part) { - CallInst.addDef(Ret.Regs[Part]); - } - } - } - - if (!Info.CanLowerReturn) { - insertSRetLoads(MIRBuilder, Info.OrigRet.Ty, Info.OrigRet.Regs, - Info.DemoteRegister, Info.DemoteStackIndex); - - for (auto Reg : Info.OrigRet.Regs) { - CallInst.addDef(Reg); - } - } + MIRBuilder.setInsertPt(MIRBuilder.getMBB(), SavedInsertPt); return true; } diff --git a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyRegisterBankInfo.cpp b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyRegisterBankInfo.cpp index e69de29bb2d1d..e605c46aece85 100644 --- a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyRegisterBankInfo.cpp +++ b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyRegisterBankInfo.cpp @@ -0,0 +1,302 @@ +#include "WebAssemblyRegisterBankInfo.h" +#include "WebAssemblySubtarget.h" +#include "WebAssemblyTargetMachine.h" +#include "llvm/CodeGen/TargetOpcodes.h" + +#define GET_TARGET_REGBANK_IMPL + +#include "WebAssemblyGenRegisterBank.inc" + +namespace llvm { +namespace WebAssembly { +enum PartialMappingIdx { + PMI_None = -1, + PMI_I32 = 1, + PMI_I64, + PMI_F32, + PMI_F64, + PMI_Min = PMI_I32, +}; + +enum ValueMappingIdx { + InvalidIdx = 0, + I32Idx = 1, + I64Idx = 4, + F32Idx = 7, + F64Idx = 10 +}; + +const RegisterBankInfo::PartialMapping PartMappings[]{{0, 32, I32RegBank}, + {0, 64, I64RegBank}, + {0, 32, F32RegBank}, + {0, 64, F64RegBank}}; + +const RegisterBankInfo::ValueMapping ValueMappings[] = { + // invalid + {nullptr, 0}, + // up to 3 operands as I32 + {&PartMappings[PMI_I32 - PMI_Min], 1}, + {&PartMappings[PMI_I32 - PMI_Min], 1}, + {&PartMappings[PMI_I32 - PMI_Min], 1}, + // up to 3 operands as I64 + {&PartMappings[PMI_I64 - PMI_Min], 1}, + {&PartMappings[PMI_I64 - PMI_Min], 1}, + 
{&PartMappings[PMI_I64 - PMI_Min], 1}, + // up to 3 operands as F32 + {&PartMappings[PMI_F32 - PMI_Min], 1}, + {&PartMappings[PMI_F32 - PMI_Min], 1}, + {&PartMappings[PMI_F32 - PMI_Min], 1}, + // up to 3 operands as F64 + {&PartMappings[PMI_F64 - PMI_Min], 1}, + {&PartMappings[PMI_F64 - PMI_Min], 1}, + {&PartMappings[PMI_F64 - PMI_Min], 1}}; + +} // namespace WebAssembly +} // namespace llvm + +using namespace llvm; + +WebAssemblyRegisterBankInfo::WebAssemblyRegisterBankInfo( + const TargetRegisterInfo &TRI) {} + +// Instructions where use operands are floating point registers. +// Def operands are general purpose. +static bool isFloatingPointOpcodeUse(unsigned Opc) { + switch (Opc) { + case TargetOpcode::G_FPTOSI: + case TargetOpcode::G_FPTOUI: + case TargetOpcode::G_FCMP: + return true; + default: + return isPreISelGenericFloatingPointOpcode(Opc); + } +} + +// Instructions where def operands are floating point registers. +// Use operands are general purpose. +static bool isFloatingPointOpcodeDef(unsigned Opc) { + switch (Opc) { + case TargetOpcode::G_SITOFP: + case TargetOpcode::G_UITOFP: + return true; + default: + return isPreISelGenericFloatingPointOpcode(Opc); + } +} + +static bool isAmbiguous(unsigned Opc) { + switch (Opc) { + case TargetOpcode::G_LOAD: + case TargetOpcode::G_STORE: + case TargetOpcode::G_PHI: + case TargetOpcode::G_SELECT: + case TargetOpcode::G_IMPLICIT_DEF: + case TargetOpcode::G_UNMERGE_VALUES: + case TargetOpcode::G_MERGE_VALUES: + return true; + default: + return false; + } +} + +const RegisterBankInfo::InstructionMapping & +WebAssemblyRegisterBankInfo::getInstrMapping(const MachineInstr &MI) const { + + unsigned Opc = MI.getOpcode(); + const MachineFunction &MF = *MI.getParent()->getParent(); + const MachineRegisterInfo &MRI = MF.getRegInfo(); + const TargetSubtargetInfo &STI = MF.getSubtarget(); + const TargetRegisterInfo &TRI = *STI.getRegisterInfo(); + + if ((Opc != TargetOpcode::COPY && !isPreISelGenericOpcode(Opc)) || + Opc == 
TargetOpcode::G_PHI) { + const RegisterBankInfo::InstructionMapping &Mapping = + getInstrMappingImpl(MI); + if (Mapping.isValid()) + return Mapping; + } + + using namespace TargetOpcode; + + unsigned NumOperands = MI.getNumOperands(); + const ValueMapping *OperandsMapping = nullptr; + unsigned MappingID = DefaultMappingID; + + // Check if LLT sizes match sizes of available register banks. + for (const MachineOperand &Op : MI.operands()) { + if (Op.isReg()) { + LLT RegTy = MRI.getType(Op.getReg()); + + if (RegTy.isScalar() && + (RegTy.getSizeInBits() != 32 && RegTy.getSizeInBits() != 64)) + return getInvalidInstructionMapping(); + + if (RegTy.isVector() && RegTy.getSizeInBits() != 128) + return getInvalidInstructionMapping(); + } + } + + switch (Opc) { + case G_BR: + return getInstructionMapping(MappingID, /*Cost=*/1, + getOperandsMapping({nullptr}), NumOperands); + case G_TRAP: + return getInstructionMapping(MappingID, /*Cost=*/1, nullptr, 0); + } + + const LLT Op0Ty = MRI.getType(MI.getOperand(0).getReg()); + unsigned Op0Size = Op0Ty.getSizeInBits(); + + auto &Op0IntValueMapping = + WebAssembly::ValueMappings[Op0Size == 64 ? WebAssembly::I64Idx + : WebAssembly::I32Idx]; + auto &Op0FloatValueMapping = + WebAssembly::ValueMappings[Op0Size == 64 ? WebAssembly::F64Idx + : WebAssembly::F32Idx]; + auto &Pointer0ValueMapping = + WebAssembly::ValueMappings[MI.getMF()->getDataLayout() + .getPointerSizeInBits(0) == 64 + ? 
WebAssembly::I64Idx + : WebAssembly::I32Idx]; + + switch (Opc) { + case G_AND: + case G_OR: + case G_XOR: + case G_SHL: + case G_ASHR: + case G_LSHR: + case G_PTR_ADD: + case G_INTTOPTR: + case G_PTRTOINT: + case G_ADD: + case G_SUB: + case G_MUL: + case G_SDIV: + case G_SREM: + case G_UDIV: + case G_UREM: + OperandsMapping = &Op0IntValueMapping; + break; + case G_SEXT_INREG: + OperandsMapping = + getOperandsMapping({&Op0IntValueMapping, &Op0IntValueMapping, nullptr}); + break; + case G_FRAME_INDEX: + OperandsMapping = getOperandsMapping({&Op0IntValueMapping, nullptr}); + break; + case G_ZEXT: + case G_ANYEXT: + case G_SEXT: + case G_TRUNC: { + const LLT Op1Ty = MRI.getType(MI.getOperand(1).getReg()); + unsigned Op1Size = Op1Ty.getSizeInBits(); + + auto &Op1IntValueMapping = + WebAssembly::ValueMappings[Op1Size == 64 ? WebAssembly::I64Idx + : WebAssembly::I32Idx]; + OperandsMapping = + getOperandsMapping({&Op0IntValueMapping, &Op1IntValueMapping}); + break; + } + case G_LOAD: + case G_STORE: + if (MRI.getType(MI.getOperand(1).getReg()).getAddressSpace() != 0) + break; + OperandsMapping = + getOperandsMapping({&Op0IntValueMapping, &Pointer0ValueMapping}); + break; + case G_MEMCPY: + case G_MEMMOVE: { + if (MRI.getType(MI.getOperand(0).getReg()).getAddressSpace() != 0) + break; + if (MRI.getType(MI.getOperand(1).getReg()).getAddressSpace() != 0) + break; + + const LLT Op2Ty = MRI.getType(MI.getOperand(2).getReg()); + unsigned Op2Size = Op2Ty.getSizeInBits(); + auto &Op2IntValueMapping = + WebAssembly::ValueMappings[Op2Size == 64 ? 
WebAssembly::I64Idx + : WebAssembly::I32Idx]; + OperandsMapping = + getOperandsMapping({&Pointer0ValueMapping, &Pointer0ValueMapping, + &Op2IntValueMapping, nullptr}); + break; + } + case G_MEMSET: { + if (MRI.getType(MI.getOperand(0).getReg()).getAddressSpace() != 0) + break; + const LLT Op1Ty = MRI.getType(MI.getOperand(1).getReg()); + unsigned Op1Size = Op1Ty.getSizeInBits(); + auto &Op1IntValueMapping = + WebAssembly::ValueMappings[Op1Size == 64 ? WebAssembly::I64Idx + : WebAssembly::I32Idx]; + + const LLT Op2Ty = MRI.getType(MI.getOperand(2).getReg()); + unsigned Op2Size = Op2Ty.getSizeInBits(); + auto &Op2IntValueMapping = + WebAssembly::ValueMappings[Op2Size == 64 ? WebAssembly::I64Idx + : WebAssembly::I32Idx]; + + OperandsMapping = + getOperandsMapping({&Pointer0ValueMapping, &Op1IntValueMapping, + &Op2IntValueMapping, nullptr}); + break; + } + case G_GLOBAL_VALUE: + case G_CONSTANT: + OperandsMapping = getOperandsMapping({&Op0IntValueMapping, nullptr}); + break; + case G_IMPLICIT_DEF: + OperandsMapping = &Op0IntValueMapping; + break; + case G_ICMP: { + const LLT Op2Ty = MRI.getType(MI.getOperand(2).getReg()); + unsigned Op2Size = Op2Ty.getSizeInBits(); + + auto &Op2IntValueMapping = + WebAssembly::ValueMappings[Op2Size == 64 ? WebAssembly::I64Idx + : WebAssembly::I32Idx]; + + OperandsMapping = + getOperandsMapping({&Op0IntValueMapping, nullptr, &Op2IntValueMapping, + &Op2IntValueMapping}); + break; + } + case G_BRCOND: + OperandsMapping = getOperandsMapping({&Op0IntValueMapping, nullptr}); + break; + case COPY: { + Register DstReg = MI.getOperand(0).getReg(); + Register SrcReg = MI.getOperand(1).getReg(); + // Check if one of the register is not a generic register. 
+ if ((DstReg.isPhysical() || !MRI.getType(DstReg).isValid()) || + (SrcReg.isPhysical() || !MRI.getType(SrcReg).isValid())) { + const RegisterBank *DstRB = getRegBank(DstReg, MRI, TRI); + const RegisterBank *SrcRB = getRegBank(SrcReg, MRI, TRI); + if (!DstRB) + DstRB = SrcRB; + else if (!SrcRB) + SrcRB = DstRB; + // If both RB are null that means both registers are generic. + // We shouldn't be here. + assert(DstRB && SrcRB && "Both RegBank were nullptr"); + TypeSize DstSize = getSizeInBits(DstReg, MRI, TRI); + TypeSize SrcSize = getSizeInBits(SrcReg, MRI, TRI); + assert(DstSize == SrcSize && + "Trying to copy between different sized regbanks? Why?"); + + return getInstructionMapping( + DefaultMappingID, copyCost(*DstRB, *SrcRB, DstSize), + getCopyMapping(DstRB->getID(), SrcRB->getID(), Size), + // We only care about the mapping of the destination. + /*NumOperands*/ 1); + } + } + } + if (!OperandsMapping) + return getInvalidInstructionMapping(); + + return getInstructionMapping(MappingID, /*Cost=*/1, OperandsMapping, + NumOperands); +} diff --git a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyRegisterBankInfo.h b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyRegisterBankInfo.h index e69de29bb2d1d..f0d95b56ef861 100644 --- a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyRegisterBankInfo.h +++ b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyRegisterBankInfo.h @@ -0,0 +1,40 @@ +//===- WebAssemblyRegisterBankInfo.h ----------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// \file +/// This file declares the targeting of the RegisterBankInfo class for WASM. +/// \todo This should be generated by TableGen. 
+//===----------------------------------------------------------------------===// + +#ifndef LLVM_LIB_TARGET_WEBASSEMBLY_WEBASSEMBLYREGISTERBANKINFO_H +#define LLVM_LIB_TARGET_WEBASSEMBLY_WEBASSEMBLYREGISTERBANKINFO_H + +#include "llvm/CodeGen/RegisterBankInfo.h" + +#define GET_REGBANK_DECLARATIONS +#include "WebAssemblyGenRegisterBank.inc" + +namespace llvm { + +class TargetRegisterInfo; + +class WebAssemblyGenRegisterBankInfo : public RegisterBankInfo { +#define GET_TARGET_REGBANK_CLASS +#include "WebAssemblyGenRegisterBank.inc" +}; + +/// This class provides the information for the target register banks. +class WebAssemblyRegisterBankInfo final + : public WebAssemblyGenRegisterBankInfo { +public: + WebAssemblyRegisterBankInfo(const TargetRegisterInfo &TRI); + + const InstructionMapping & + getInstrMapping(const MachineInstr &MI) const override; +}; +} // end namespace llvm +#endif diff --git a/llvm/lib/Target/WebAssembly/WebAssembly.td b/llvm/lib/Target/WebAssembly/WebAssembly.td index 089be5f1dc70e..3705a42fd21c9 100644 --- a/llvm/lib/Target/WebAssembly/WebAssembly.td +++ b/llvm/lib/Target/WebAssembly/WebAssembly.td @@ -101,6 +101,7 @@ def FeatureWideArithmetic : //===----------------------------------------------------------------------===// include "WebAssemblyRegisterInfo.td" +include "WebAssemblyRegisterBanks.td" //===----------------------------------------------------------------------===// // Instruction Descriptions diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyRegisterBanks.td b/llvm/lib/Target/WebAssembly/WebAssemblyRegisterBanks.td new file mode 100644 index 0000000000000..9ebece0e0bf09 --- /dev/null +++ b/llvm/lib/Target/WebAssembly/WebAssemblyRegisterBanks.td @@ -0,0 +1,20 @@ +//=- WebAssemblyRegisterBank.td - Describe the WASM Banks ----*- tablegen -*-=// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// +//===----------------------------------------------------------------------===// + + +def I32RegBank : RegisterBank<"I32RegBank", [I32]>; +def I64RegBank : RegisterBank<"I64RegBank", [I64]>; +def F32RegBank : RegisterBank<"F32RegBank", [F32]>; +def F64RegBank : RegisterBank<"F64RegBank", [F64]>; + +def EXTERNREFRegBank : RegisterBank<"EXTERNREFRegBank", [EXTERNREF]>; +def FUNCREFRegBank : RegisterBank<"FUNCREFRegBank", [FUNCREF]>; +def EXNREFRegBank : RegisterBank<"EXNREFRegBank", [EXNREF]>; diff --git a/llvm/lib/Target/WebAssembly/WebAssemblySubtarget.cpp b/llvm/lib/Target/WebAssembly/WebAssemblySubtarget.cpp index 3ea8b9f85819f..b99c35acabef6 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblySubtarget.cpp +++ b/llvm/lib/Target/WebAssembly/WebAssemblySubtarget.cpp @@ -73,9 +73,9 @@ WebAssemblySubtarget::WebAssemblySubtarget(const Triple &TT, TLInfo(TM, *this) { CallLoweringInfo.reset(new WebAssemblyCallLowering(*getTargetLowering())); Legalizer.reset(new WebAssemblyLegalizerInfo(*this)); - /*auto *RBI = new WebAssemblyRegisterBankInfo(*getRegisterInfo()); + auto *RBI = new WebAssemblyRegisterBankInfo(*getRegisterInfo()); RegBankInfo.reset(RBI); - +/* InstSelector.reset(createWebAssemblyInstructionSelector( *static_cast(&TM), *this, *RBI));*/ } From c0c68c5a03c5724b2195218ad4b927b3e77eb302 Mon Sep 17 00:00:00 2001 From: Demetrius Kanios Date: Mon, 29 Sep 2025 11:23:22 -0700 Subject: [PATCH 09/17] Finish initial pass over regbankselect --- .../GISel/WebAssemblyCallLowering.cpp | 21 +- .../GISel/WebAssemblyLegalizerInfo.cpp | 55 ++-- .../GISel/WebAssemblyRegisterBankInfo.cpp | 254 +++++++++++++----- 3 files changed, 241 insertions(+), 89 deletions(-) diff --git a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.cpp b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.cpp index
7f3c7e62ce02d..733d676ac988a 100644 --- a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.cpp +++ b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.cpp @@ -33,6 +33,7 @@ #include "llvm/CodeGen/MachineInstrBuilder.h" #include "llvm/CodeGen/MachineMemOperand.h" #include "llvm/CodeGen/TargetOpcodes.h" +#include "llvm/CodeGen/TargetRegisterInfo.h" #include "llvm/CodeGenTypes/LowLevelType.h" #include "llvm/IR/Argument.h" #include "llvm/IR/DataLayout.h" @@ -481,6 +482,7 @@ bool WebAssemblyCallLowering::lowerReturn(MachineIRBuilder &MIRBuilder, MVT NewVT = TLI.getRegisterTypeForCallingConv(Ctx, CallConv, OrigVT); LLT OrigLLT = getLLTForType(*Arg.Ty, DL); LLT NewLLT = getLLTForMVT(NewVT); + const TargetRegisterClass &NewRegClass = *TLI.getRegClassFor(NewVT); // If we need to split the type over multiple regs, check it's a scenario // we currently support. @@ -522,10 +524,12 @@ bool WebAssemblyCallLowering::lowerReturn(MachineIRBuilder &MIRBuilder, } for (unsigned Part = 0; Part < NumParts; ++Part) { - auto NewOutReg = constrainRegToClass(MRI, TII, RBI, Arg.Regs[Part], - *TLI.getRegClassFor(NewVT)); - if (NewOutReg != Arg.Regs[Part]) + auto NewOutReg = Arg.Regs[Part]; + if (!RBI.constrainGenericRegister(NewOutReg, NewRegClass, MRI)) { + NewOutReg = MRI.createGenericVirtualRegister(NewLLT); + assert(RBI.constrainGenericRegister(NewOutReg, NewRegClass, MRI) && "Couldn't constrain brand-new register?"); MIRBuilder.buildCopy(NewOutReg, Arg.Regs[Part]); + } MIB.addUse(NewOutReg); } } @@ -864,6 +868,7 @@ bool WebAssemblyCallLowering::lowerCall(MachineIRBuilder &MIRBuilder, MVT NewVT = TLI.getRegisterTypeForCallingConv(Ctx, CallConv, OrigVT); LLT OrigLLT = getLLTForType(*Ret.Ty, DL); LLT NewLLT = getLLTForMVT(NewVT); + const TargetRegisterClass &NewRegClass = *TLI.getRegClassFor(NewVT); // If we need to split the type over multiple regs, check it's a // scenario we currently support. 
@@ -903,12 +908,12 @@ bool WebAssemblyCallLowering::lowerCall(MachineIRBuilder &MIRBuilder, } for (unsigned Part = 0; Part < NumParts; ++Part) { - // MRI.setRegClass(Ret.Regs[Part], TLI.getRegClassFor(NewVT)); - auto NewRetReg = constrainRegToClass(MRI, TII, RBI, Ret.Regs[Part], - *TLI.getRegClassFor(NewVT)); - if (Ret.Regs[Part] != NewRetReg) + auto NewRetReg = Ret.Regs[Part]; + if (!RBI.constrainGenericRegister(NewRetReg, NewRegClass, MRI)) { + NewRetReg = MRI.createGenericVirtualRegister(NewLLT); + assert(RBI.constrainGenericRegister(NewRetReg, NewRegClass, MRI) && "Couldn't constrain brand-new register?"); MIRBuilder.buildCopy(NewRetReg, Ret.Regs[Part]); - + } CallInst.addDef(Ret.Regs[Part]); } } diff --git a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyLegalizerInfo.cpp b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyLegalizerInfo.cpp index c6cd1c5b371e9..ae2ac0a512427 100644 --- a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyLegalizerInfo.cpp +++ b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyLegalizerInfo.cpp @@ -89,14 +89,17 @@ WebAssemblyLegalizerInfo::WebAssemblyLegalizerInfo( .clampScalar(0, s32, s64); getActionDefinitionsBuilder({G_ASHR, G_LSHR, G_SHL, G_CTLZ, G_CTLZ_ZERO_UNDEF, - G_CTTZ, G_CTTZ_ZERO_UNDEF, G_CTPOP, G_FSHL, - G_FSHR}) + G_CTTZ, G_CTTZ_ZERO_UNDEF, G_CTPOP}) .legalFor({{s32, s32}, {s64, s64}}) .widenScalarToNextPow2(0) .clampScalar(0, s32, s64) .minScalarSameAs(1, 0) .maxScalarSameAs(1, 0); + getActionDefinitionsBuilder({G_FSHL, G_FSHR}) + .legalFor({{s32, s32}, {s64, s64}}) + .lower(); + getActionDefinitionsBuilder({G_SCMP, G_UCMP}).lower(); getActionDefinitionsBuilder({G_AND, G_OR, G_XOR}) @@ -123,6 +126,12 @@ WebAssemblyLegalizerInfo::WebAssemblyLegalizerInfo( .libcallFor({s32, s64}) .minScalar(0, s32); + getActionDefinitionsBuilder(G_LROUND).libcallForCartesianProduct({s32}, + {s32, s64}); + + getActionDefinitionsBuilder(G_LLROUND).libcallForCartesianProduct({s64}, + {s32, s64}); + getActionDefinitionsBuilder(G_FCOPYSIGN) 
.legalFor({s32, s64}) .minScalar(0, s32) @@ -154,9 +163,8 @@ WebAssemblyLegalizerInfo::WebAssemblyLegalizerInfo( {s64, p0, s8, 1}, {s64, p0, s16, 1}, {s64, p0, s32, 1}}) - .widenScalarToNextPow2(0) - .lowerIfMemSizeNotByteSizePow2() - .clampScalar(0, s32, s64); + .clampScalar(0, s32, s64) + .lowerIfMemSizeNotByteSizePow2(); getActionDefinitionsBuilder(G_STORE) .legalForTypesWithMemDesc( @@ -167,9 +175,8 @@ WebAssemblyLegalizerInfo::WebAssemblyLegalizerInfo( {s64, p0, s8, 1}, {s64, p0, s16, 1}, {s64, p0, s32, 1}}) - .widenScalarToNextPow2(0) - .lowerIfMemSizeNotByteSizePow2() - .clampScalar(0, s32, s64); + .clampScalar(0, s32, s64) + .lowerIfMemSizeNotByteSizePow2(); getActionDefinitionsBuilder({G_ZEXTLOAD, G_SEXTLOAD}) .legalForTypesWithMemDesc({{s32, p0, s8, 1}, @@ -178,10 +185,8 @@ WebAssemblyLegalizerInfo::WebAssemblyLegalizerInfo( {s64, p0, s8, 1}, {s64, p0, s16, 1}, {s64, p0, s32, 1}}) - .widenScalarToNextPow2(0) - .lowerIfMemSizeNotByteSizePow2() .clampScalar(0, s32, s64) - .lower(); + .lowerIfMemSizeNotByteSizePow2(); if (ST.hasBulkMemoryOpt()) { getActionDefinitionsBuilder(G_BZERO).unsupported(); @@ -204,7 +209,7 @@ WebAssemblyLegalizerInfo::WebAssemblyLegalizerInfo( // TODO: figure out how to combine G_ANYEXT of G_ASSERT_{S|Z}EXT (or // appropriate G_AND and G_SEXT_IN_REG?) to a G_{S|Z}EXT + G_ASSERT_{S|Z}EXT - // for better optimization (since G_ANYEXT lowers to a ZEXT or SEXT + // for better optimization (since G_ANYEXT will lower to a ZEXT or SEXT // instruction anyway). 
getActionDefinitionsBuilder(G_ANYEXT) @@ -221,8 +226,7 @@ WebAssemblyLegalizerInfo::WebAssemblyLegalizerInfo( if (ST.hasSignExt()) { getActionDefinitionsBuilder(G_SEXT_INREG) .clampScalar(0, s32, s64) - .customFor({s32, s64}) - .lower(); + .customFor({s32, s64}); } else { getActionDefinitionsBuilder(G_SEXT_INREG).lower(); } @@ -242,23 +246,42 @@ WebAssemblyLegalizerInfo::WebAssemblyLegalizerInfo( .legalForCartesianProduct({s32, s64}, {p0}) .clampScalar(0, s32, s64); + getActionDefinitionsBuilder(G_DYN_STACKALLOC).lowerFor({{p0, p0s}}); + + getActionDefinitionsBuilder({G_STACKSAVE, G_STACKRESTORE}).lower(); + getLegacyLegalizerInfo().computeTables(); } bool WebAssemblyLegalizerInfo::legalizeCustom( LegalizerHelper &Helper, MachineInstr &MI, LostDebugLocObserver &LocObserver) const { + auto &MRI = *Helper.MIRBuilder.getMRI(); + auto &MIRBuilder = Helper.MIRBuilder; + switch (MI.getOpcode()) { case TargetOpcode::G_SEXT_INREG: { + assert(MI.getOperand(2).isImm() && "Expected immediate"); + // Mark only 8/16/32-bit SEXT_INREG as legal - auto [DstType, SrcType] = MI.getFirst2LLTs(); + auto [DstReg, SrcReg] = MI.getFirst2Regs(); + auto DstType = MRI.getType(DstReg); auto ExtFromWidth = MI.getOperand(2).getImm(); if (ExtFromWidth == 8 || ExtFromWidth == 16 || (DstType.getScalarSizeInBits() == 64 && ExtFromWidth == 32)) { return true; } - return false; + + Register TmpRes = MRI.createGenericVirtualRegister(DstType); + + auto MIBSz = MIRBuilder.buildConstant( + DstType, DstType.getScalarSizeInBits() - ExtFromWidth); + MIRBuilder.buildShl(TmpRes, SrcReg, MIBSz->getOperand(0)); + MIRBuilder.buildAShr(DstReg, TmpRes, MIBSz->getOperand(0)); + MI.eraseFromParent(); + + return true; } case TargetOpcode::G_MEMSET: { // Anyext the value being set to 32 bit (only the bottom 8 bits are read by diff --git a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyRegisterBankInfo.cpp b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyRegisterBankInfo.cpp index e605c46aece85..fa4103a8b1b31 100644 --- 
a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyRegisterBankInfo.cpp +++ b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyRegisterBankInfo.cpp @@ -1,7 +1,9 @@ #include "WebAssemblyRegisterBankInfo.h" +#include "MCTargetDesc/WebAssemblyMCTargetDesc.h" #include "WebAssemblySubtarget.h" #include "WebAssemblyTargetMachine.h" #include "llvm/CodeGen/TargetOpcodes.h" +#include "llvm/Support/ErrorHandling.h" #define GET_TARGET_REGBANK_IMPL @@ -59,46 +61,6 @@ using namespace llvm; WebAssemblyRegisterBankInfo::WebAssemblyRegisterBankInfo( const TargetRegisterInfo &TRI) {} -// Instructions where use operands are floating point registers. -// Def operands are general purpose. -static bool isFloatingPointOpcodeUse(unsigned Opc) { - switch (Opc) { - case TargetOpcode::G_FPTOSI: - case TargetOpcode::G_FPTOUI: - case TargetOpcode::G_FCMP: - return true; - default: - return isPreISelGenericFloatingPointOpcode(Opc); - } -} - -// Instructions where def operands are floating point registers. -// Use operands are general purpose. 
-static bool isFloatingPointOpcodeDef(unsigned Opc) { - switch (Opc) { - case TargetOpcode::G_SITOFP: - case TargetOpcode::G_UITOFP: - return true; - default: - return isPreISelGenericFloatingPointOpcode(Opc); - } -} - -static bool isAmbiguous(unsigned Opc) { - switch (Opc) { - case TargetOpcode::G_LOAD: - case TargetOpcode::G_STORE: - case TargetOpcode::G_PHI: - case TargetOpcode::G_SELECT: - case TargetOpcode::G_IMPLICIT_DEF: - case TargetOpcode::G_UNMERGE_VALUES: - case TargetOpcode::G_MERGE_VALUES: - return true; - default: - return false; - } -} - const RegisterBankInfo::InstructionMapping & WebAssemblyRegisterBankInfo::getInstrMapping(const MachineInstr &MI) const { @@ -135,13 +97,35 @@ WebAssemblyRegisterBankInfo::getInstrMapping(const MachineInstr &MI) const { return getInvalidInstructionMapping(); } } - switch (Opc) { case G_BR: return getInstructionMapping(MappingID, /*Cost=*/1, getOperandsMapping({nullptr}), NumOperands); case G_TRAP: - return getInstructionMapping(MappingID, /*Cost=*/1, nullptr, 0); + case G_DEBUGTRAP: + return getInstructionMapping(MappingID, /*Cost=*/1, getOperandsMapping({}), + 0); + case COPY: + Register DstReg = MI.getOperand(0).getReg(); + if (DstReg.isPhysical()) { + if (DstReg.id() == WebAssembly::SP32) { + return getInstructionMapping( + MappingID, /*Cost=*/1, + getOperandsMapping( + {&WebAssembly::ValueMappings[WebAssembly::I32Idx]}), + 1); + } else if (DstReg.id() == WebAssembly::SP64) { + return getInstructionMapping( + MappingID, /*Cost=*/1, + getOperandsMapping( + {&WebAssembly::ValueMappings[WebAssembly::I64Idx]}), + 1); + } else { + llvm_unreachable("Trying to copy into WASM physical register other " + "than sp32 or sp64?"); + } + } + break; } const LLT Op0Ty = MRI.getType(MI.getOperand(0).getReg()); @@ -176,8 +160,39 @@ WebAssemblyRegisterBankInfo::getInstrMapping(const MachineInstr &MI) const { case G_SREM: case G_UDIV: case G_UREM: + case G_CTLZ: + case G_CTLZ_ZERO_UNDEF: + case G_CTTZ: + case G_CTTZ_ZERO_UNDEF: + 
case G_CTPOP: + case G_FSHL: + case G_FSHR: OperandsMapping = &Op0IntValueMapping; break; + case G_FADD: + case G_FSUB: + case G_FDIV: + case G_FMUL: + case G_FNEG: + case G_FABS: + case G_FCEIL: + case G_FFLOOR: + case G_FSQRT: + case G_INTRINSIC_TRUNC: + case G_FNEARBYINT: + case G_FRINT: + case G_INTRINSIC_ROUNDEVEN: + case G_FMINIMUM: + case G_FMAXIMUM: + case G_FMINNUM: + case G_FMAXNUM: + case G_FMINNUM_IEEE: + case G_FMAXNUM_IEEE: + case G_FMA: + case G_FREM: + case G_FCOPYSIGN: + OperandsMapping = &Op0FloatValueMapping; + break; case G_SEXT_INREG: OperandsMapping = getOperandsMapping({&Op0IntValueMapping, &Op0IntValueMapping, nullptr}); @@ -185,6 +200,9 @@ WebAssemblyRegisterBankInfo::getInstrMapping(const MachineInstr &MI) const { case G_FRAME_INDEX: OperandsMapping = getOperandsMapping({&Op0IntValueMapping, nullptr}); break; + case G_VASTART: + OperandsMapping = &Op0IntValueMapping; + break; case G_ZEXT: case G_ANYEXT: case G_SEXT: @@ -233,7 +251,7 @@ WebAssemblyRegisterBankInfo::getInstrMapping(const MachineInstr &MI) const { : WebAssembly::I32Idx]; const LLT Op2Ty = MRI.getType(MI.getOperand(2).getReg()); - unsigned Op2Size = Op1Ty.getSizeInBits(); + unsigned Op2Size = Op2Ty.getSizeInBits(); auto &Op2IntValueMapping = WebAssembly::ValueMappings[Op2Size == 64 ? 
WebAssembly::I64Idx : WebAssembly::I32Idx]; @@ -247,6 +265,9 @@ WebAssemblyRegisterBankInfo::getInstrMapping(const MachineInstr &MI) const { case G_CONSTANT: OperandsMapping = getOperandsMapping({&Op0IntValueMapping, nullptr}); break; + case G_FCONSTANT: + OperandsMapping = getOperandsMapping({&Op0FloatValueMapping, nullptr}); + break; case G_IMPLICIT_DEF: OperandsMapping = &Op0IntValueMapping; break; @@ -263,37 +284,140 @@ WebAssemblyRegisterBankInfo::getInstrMapping(const MachineInstr &MI) const { &Op2IntValueMapping}); break; } + case G_FCMP: { + const LLT Op2Ty = MRI.getType(MI.getOperand(2).getReg()); + unsigned Op2Size = Op2Ty.getSizeInBits(); + + auto &Op2FloatValueMapping = + WebAssembly::ValueMappings[Op2Size == 64 ? WebAssembly::F64Idx + : WebAssembly::F32Idx]; + + OperandsMapping = + getOperandsMapping({&Op0IntValueMapping, nullptr, &Op2FloatValueMapping, + &Op2FloatValueMapping}); + break; + } case G_BRCOND: OperandsMapping = getOperandsMapping({&Op0IntValueMapping, nullptr}); break; + case G_JUMP_TABLE: + OperandsMapping = getOperandsMapping({&Op0IntValueMapping, nullptr}); + break; + case G_BRJT: + OperandsMapping = + getOperandsMapping({&Op0IntValueMapping, nullptr, + &WebAssembly::ValueMappings[WebAssembly::I32Idx]}); + break; case COPY: { Register DstReg = MI.getOperand(0).getReg(); Register SrcReg = MI.getOperand(1).getReg(); - // Check if one of the register is not a generic register. - if ((DstReg.isPhysical() || !MRI.getType(DstReg).isValid()) || - (SrcReg.isPhysical() || !MRI.getType(SrcReg).isValid())) { - const RegisterBank *DstRB = getRegBank(DstReg, MRI, TRI); - const RegisterBank *SrcRB = getRegBank(SrcReg, MRI, TRI); - if (!DstRB) - DstRB = SrcRB; - else if (!SrcRB) - SrcRB = DstRB; - // If both RB are null that means both registers are generic. - // We shouldn't be here. 
- assert(DstRB && SrcRB && "Both RegBank were nullptr"); - TypeSize DstSize = getSizeInBits(DstReg, MRI, TRI); - TypeSize SrcSize = getSizeInBits(SrcReg, MRI, TRI); - assert(DstSize == SrcSize && - "Trying to copy between different sized regbanks? Why?"); - - return getInstructionMapping( - DefaultMappingID, copyCost(*DstRB, *SrcRB, DstSize), - getCopyMapping(DstRB->getID(), SrcRB->getID(), Size), - // We only care about the mapping of the destination. - /*NumOperands*/ 1); + + const RegisterBank *DstRB = getRegBank(DstReg, MRI, TRI); + const RegisterBank *SrcRB = getRegBank(SrcReg, MRI, TRI); + + if (!DstRB) + DstRB = SrcRB; + else if (!SrcRB) + SrcRB = DstRB; + + assert(DstRB && SrcRB && "Both RegBank were nullptr"); + TypeSize DstSize = getSizeInBits(DstReg, MRI, TRI); + TypeSize SrcSize = getSizeInBits(SrcReg, MRI, TRI); + assert(DstSize == SrcSize && + "Trying to copy between different sized regbanks? Why?"); + + WebAssembly::ValueMappingIdx DstValMappingIdx = WebAssembly::InvalidIdx; + switch (DstRB->getID()) { + case WebAssembly::I32RegBankID: + DstValMappingIdx = WebAssembly::I32Idx; + break; + case WebAssembly::I64RegBankID: + DstValMappingIdx = WebAssembly::I64Idx; + break; + case WebAssembly::F32RegBankID: + DstValMappingIdx = WebAssembly::F32Idx; + break; + case WebAssembly::F64RegBankID: + DstValMappingIdx = WebAssembly::F64Idx; + break; + default: + break; + } + + WebAssembly::ValueMappingIdx SrcValMappingIdx = WebAssembly::InvalidIdx; + switch (SrcRB->getID()) { + case WebAssembly::I32RegBankID: + SrcValMappingIdx = WebAssembly::I32Idx; + break; + case WebAssembly::I64RegBankID: + SrcValMappingIdx = WebAssembly::I64Idx; + break; + case WebAssembly::F32RegBankID: + SrcValMappingIdx = WebAssembly::F32Idx; + break; + case WebAssembly::F64RegBankID: + SrcValMappingIdx = WebAssembly::F64Idx; + break; + default: + break; } + + OperandsMapping = + getOperandsMapping({&WebAssembly::ValueMappings[DstValMappingIdx], + 
&WebAssembly::ValueMappings[SrcValMappingIdx]}); + return getInstructionMapping( + MappingID, /*Cost=*/copyCost(*DstRB, *SrcRB, DstSize), OperandsMapping, + // We only care about the mapping of the destination for COPY. + 1); + } + case G_SELECT: + OperandsMapping = getOperandsMapping( + {&Op0IntValueMapping, &WebAssembly::ValueMappings[WebAssembly::I32Idx], + &Op0IntValueMapping, &Op0IntValueMapping}); + break; + case G_FPTOSI: + case G_FPTOSI_SAT: + case G_FPTOUI: + case G_FPTOUI_SAT: { + const LLT Op1Ty = MRI.getType(MI.getOperand(1).getReg()); + unsigned Op1Size = Op1Ty.getSizeInBits(); + + auto &Op1FloatValueMapping = + WebAssembly::ValueMappings[Op1Size == 64 ? WebAssembly::F64Idx + : WebAssembly::F32Idx]; + + OperandsMapping = + getOperandsMapping({&Op0IntValueMapping, &Op1FloatValueMapping}); + break; } + case G_SITOFP: + case G_UITOFP: { + const LLT Op1Ty = MRI.getType(MI.getOperand(1).getReg()); + unsigned Op1Size = Op1Ty.getSizeInBits(); + + auto &Op1IntValueMapping = + WebAssembly::ValueMappings[Op1Size == 64 ? WebAssembly::I64Idx + : WebAssembly::I32Idx]; + + OperandsMapping = + getOperandsMapping({&Op0FloatValueMapping, &Op1IntValueMapping}); + break; } + case G_FPEXT: + case G_FPTRUNC: { + const LLT Op1Ty = MRI.getType(MI.getOperand(1).getReg()); + unsigned Op1Size = Op1Ty.getSizeInBits(); + + auto &Op1FloatValueMapping = + WebAssembly::ValueMappings[Op1Size == 64 ? 
WebAssembly::F64Idx + : WebAssembly::F32Idx]; + + OperandsMapping = + getOperandsMapping({&Op0FloatValueMapping, &Op1FloatValueMapping}); + break; + } + } + if (!OperandsMapping) return getInvalidInstructionMapping(); From 3ca5826904adf63f19074b18bfb4d2a320b48f24 Mon Sep 17 00:00:00 2001 From: Demetrius Kanios Date: Thu, 2 Oct 2025 00:18:46 -0700 Subject: [PATCH 10/17] Begin work on instruction selection --- llvm/lib/Target/WebAssembly/CMakeLists.txt | 1 + .../GISel/WebAssemblyInstructionSelector.cpp | 526 ++++++++++++++++++ .../GISel/WebAssemblyInstructionSelector.h | 0 llvm/lib/Target/WebAssembly/WebAssembly.h | 9 + .../WebAssembly/WebAssemblyInstrMemory.td | 6 + .../WebAssembly/WebAssemblySubtarget.cpp | 5 +- 6 files changed, 545 insertions(+), 2 deletions(-) delete mode 100644 llvm/lib/Target/WebAssembly/GISel/WebAssemblyInstructionSelector.h diff --git a/llvm/lib/Target/WebAssembly/CMakeLists.txt b/llvm/lib/Target/WebAssembly/CMakeLists.txt index e0be3d6af2f0c..e80850eb073fb 100644 --- a/llvm/lib/Target/WebAssembly/CMakeLists.txt +++ b/llvm/lib/Target/WebAssembly/CMakeLists.txt @@ -7,6 +7,7 @@ tablegen(LLVM WebAssemblyGenAsmWriter.inc -gen-asm-writer) tablegen(LLVM WebAssemblyGenDAGISel.inc -gen-dag-isel) tablegen(LLVM WebAssemblyGenDisassemblerTables.inc -gen-disassembler) tablegen(LLVM WebAssemblyGenFastISel.inc -gen-fast-isel) +tablegen(LLVM WebAssemblyGenGlobalISel.inc -gen-global-isel) tablegen(LLVM WebAssemblyGenInstrInfo.inc -gen-instr-info) tablegen(LLVM WebAssemblyGenMCCodeEmitter.inc -gen-emitter) tablegen(LLVM WebAssemblyGenRegisterBank.inc -gen-register-bank) diff --git a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyInstructionSelector.cpp b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyInstructionSelector.cpp index e69de29bb2d1d..aea7b9a424a62 100644 --- a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyInstructionSelector.cpp +++ b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyInstructionSelector.cpp @@ -0,0 +1,526 @@ +//===- 
WebAssemblyInstructionSelector.cpp ------------------------*- C++ -*-==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// \file +/// This file implements the targeting of the InstructionSelector class for +/// WebAssembly. +/// \todo This should be generated by TableGen. +//===----------------------------------------------------------------------===// + +#include "GISel/WebAssemblyRegisterBankInfo.h" +#include "MCTargetDesc/WebAssemblyMCTargetDesc.h" +#include "Utils/WasmAddressSpaces.h" +#include "Utils/WebAssemblyTypeUtilities.h" +#include "WebAssemblyRegisterInfo.h" +#include "WebAssemblySubtarget.h" +#include "WebAssemblyTargetMachine.h" +#include "llvm/CodeGen/GlobalISel/GIMatchTableExecutorImpl.h" +#include "llvm/CodeGen/GlobalISel/InstructionSelector.h" +#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h" +#include "llvm/CodeGen/GlobalISel/Utils.h" +#include "llvm/CodeGen/MachineFunction.h" +#include "llvm/CodeGen/MachineOperand.h" +#include "llvm/CodeGen/RegisterBank.h" +#include "llvm/CodeGen/TargetLowering.h" +#include "llvm/CodeGen/TargetOpcodes.h" +#include "llvm/IR/IntrinsicsWebAssembly.h" +#include "llvm/MC/TargetRegistry.h" +#include "llvm/Support/ErrorHandling.h" + +#define DEBUG_TYPE "wasm-isel" + +using namespace llvm; + +namespace { + +#define GET_GLOBALISEL_PREDICATE_BITSET +#include "WebAssemblyGenGlobalISel.inc" +#undef GET_GLOBALISEL_PREDICATE_BITSET + +class WebAssemblyInstructionSelector : public InstructionSelector { +public: + WebAssemblyInstructionSelector(const WebAssemblyTargetMachine &TM, + const WebAssemblySubtarget &STI, + const WebAssemblyRegisterBankInfo &RBI); + + bool select(MachineInstr &I) override; + + InstructionSelector::ComplexRendererFns + 
selectAddrOperands32(MachineOperand &Root) const; + InstructionSelector::ComplexRendererFns + selectAddrOperands64(MachineOperand &Root) const; + + static const char *getName() { return DEBUG_TYPE; } + +private: + bool selectImpl(MachineInstr &I, CodeGenCoverage &CoverageInfo) const; + bool selectCopy(MachineInstr &I, MachineRegisterInfo &MRI) const; + + InstructionSelector::ComplexRendererFns + selectAddrOperands(LLT AddrType, unsigned int ConstOpc, + MachineOperand &Root) const; + + const WebAssemblyTargetMachine &TM; + const WebAssemblySubtarget &STI; + const WebAssemblyInstrInfo &TII; + const WebAssemblyRegisterInfo &TRI; + const WebAssemblyRegisterBankInfo &RBI; + +#define GET_GLOBALISEL_PREDICATES_DECL +#include "WebAssemblyGenGlobalISel.inc" +#undef GET_GLOBALISEL_PREDICATES_DECL + +#define GET_GLOBALISEL_TEMPORARIES_DECL +#include "WebAssemblyGenGlobalISel.inc" +#undef GET_GLOBALISEL_TEMPORARIES_DECL +}; + +} // end anonymous namespace + +#define GET_GLOBALISEL_IMPL +#include "WebAssemblyGenGlobalISel.inc" +#undef GET_GLOBALISEL_IMPL + +WebAssemblyInstructionSelector::WebAssemblyInstructionSelector( + const WebAssemblyTargetMachine &TM, const WebAssemblySubtarget &STI, + const WebAssemblyRegisterBankInfo &RBI) + : TM(TM), STI(STI), TII(*STI.getInstrInfo()), TRI(*STI.getRegisterInfo()), + RBI(RBI), + +#define GET_GLOBALISEL_PREDICATES_INIT +#include "WebAssemblyGenGlobalISel.inc" +#undef GET_GLOBALISEL_PREDICATES_INIT +#define GET_GLOBALISEL_TEMPORARIES_INIT +#include "WebAssemblyGenGlobalISel.inc" +#undef GET_GLOBALISEL_TEMPORARIES_INIT +{ +} + +InstructionSelector::ComplexRendererFns +WebAssemblyInstructionSelector::selectAddrOperands(LLT AddrType, + unsigned int ConstOpc, + MachineOperand &Root) const { + + if (!Root.isReg()) + return std::nullopt; + + MachineRegisterInfo &MRI = + Root.getParent()->getParent()->getParent()->getRegInfo(); + MachineInstr &RootDef = *MRI.getVRegDef(Root.getReg()); + + if (RootDef.getOpcode() == TargetOpcode::G_PTR_ADD) { + 
// RootDef will always be G_PTR_ADD + MachineOperand &LHS = RootDef.getOperand(1); + + MachineOperand &RHS = RootDef.getOperand(2); + MachineInstr &LHSDef = *MRI.getVRegDef(LHS.getReg()); + MachineInstr &RHSDef = + *MRI.getVRegDef(RHS.getReg()); // Will always be G_CONSTANT + + // WebAssembly constant offsets are performed as unsigned with infinite + // precision, so we need to check for NoUnsignedWrap so that we don't fold + // an offset for an add that needs wrapping. + if (RootDef.getFlag(MachineInstr::MIFlag::NoUWrap)) { + for (size_t i = 0; i < 2; ++i) { + //MachineOperand &Op = i == 0 ? LHS : RHS; + MachineInstr &OpDef = i == 0 ? LHSDef : RHSDef; + MachineOperand &OtherOp = i == 0 ? RHS : LHS; + //MachineInstr &OtherOpDef = i == 0 ? RHSDef : LHSDef; + + if (OpDef.getOpcode() == TargetOpcode::G_CONSTANT) { + auto Offset = OpDef.getOperand(1).getCImm()->getZExtValue(); + auto Addr = OtherOp; + + return {{ + [=](MachineInstrBuilder &MIB) { MIB.addImm(Offset); }, + [=](MachineInstrBuilder &MIB) { MIB.add(Addr); }, + }}; + } + + if (!TM.isPositionIndependent()) { + if (OpDef.getOpcode() == TargetOpcode::G_GLOBAL_VALUE) { + auto Offset = OpDef.getOperand(1).getGlobal(); + auto Addr = OtherOp; + + return {{ + [=](MachineInstrBuilder &MIB) { MIB.addGlobalAddress(Offset); }, + [=](MachineInstrBuilder &MIB) { MIB.add(Addr); }, + }}; + } + } + } + } + } + + if (RootDef.getOpcode() == TargetOpcode::G_CONSTANT) { + auto Offset = RootDef.getOperand(1).getCImm()->getZExtValue(); + auto Addr = MRI.createGenericVirtualRegister(AddrType); + + MachineIRBuilder B(RootDef); + + auto MIB = B.buildInstr(ConstOpc).addDef(Addr).addImm(0); + assert(constrainSelectedInstRegOperands(*MIB, TII, TRI, RBI) && + "Couldn't constrain registers for instruction"); + + return {{ + [=](MachineInstrBuilder &MIB) { MIB.addImm(Offset); }, + [=](MachineInstrBuilder &MIB) { MIB.addReg(Addr); }, + }}; + } + + return {{ + [=](MachineInstrBuilder &MIB) { MIB.addImm(0); }, + [=](MachineInstrBuilder &MIB) 
{ MIB.add(Root); }, + }}; +} + +InstructionSelector::ComplexRendererFns +WebAssemblyInstructionSelector::selectAddrOperands32( + MachineOperand &Root) const { + return selectAddrOperands(LLT::scalar(32), WebAssembly::CONST_I32, Root); +} + +InstructionSelector::ComplexRendererFns +WebAssemblyInstructionSelector::selectAddrOperands64( + MachineOperand &Root) const { + return selectAddrOperands(LLT::scalar(64), WebAssembly::CONST_I64, Root); +} + +bool WebAssemblyInstructionSelector::selectCopy( + MachineInstr &I, MachineRegisterInfo &MRI) const { + Register DstReg = I.getOperand(0).getReg(); + Register SrcReg = I.getOperand(1).getReg(); + + if (DstReg.isPhysical()) { + if (DstReg.id() == WebAssembly::SP32) { + if (!RBI.constrainGenericRegister(DstReg, WebAssembly::I32RegClass, + MRI)) { + LLVM_DEBUG(dbgs() << "Failed to constrain " + << TII.getName(I.getOpcode()) << " operand\n"); + return false; + } + return true; + } + if (DstReg.id() == WebAssembly::SP64) { + if (!RBI.constrainGenericRegister(DstReg, WebAssembly::I64RegClass, + MRI)) { + LLVM_DEBUG(dbgs() << "Failed to constrain " + << TII.getName(I.getOpcode()) << " operand\n"); + return false; + } + return true; + } + llvm_unreachable("Copy to physical register other than SP32 or SP64?"); + } + + const TargetRegisterClass *DstRC = MRI.getRegClassOrNull(DstReg); + if (!DstRC) { + const RegisterBank *DstBank = MRI.getRegBankOrNull(DstReg); + if (!DstBank) { + llvm_unreachable("Selecting copy with dst reg with no bank?"); + } + + switch (DstBank->getID()) { + case WebAssembly::I32RegBankID: + DstRC = &WebAssembly::I32RegClass; + break; + case WebAssembly::I64RegBankID: + DstRC = &WebAssembly::I64RegClass; + break; + case WebAssembly::F32RegBankID: + DstRC = &WebAssembly::F32RegClass; + break; + case WebAssembly::F64RegBankID: + DstRC = &WebAssembly::F64RegClass; + break; + default: + llvm_unreachable("Unknown reg bank to reg class mapping?"); + } + if (!RBI.constrainGenericRegister(DstReg, *DstRC, MRI)) { + 
LLVM_DEBUG(dbgs() << "Failed to constrain " << TII.getName(I.getOpcode()) + << " operand\n"); + return false; + } + } + + const TargetRegisterClass *SrcRC = MRI.getRegClassOrNull(SrcReg); + if (!SrcRC) { + const RegisterBank *SrcBank = MRI.getRegBankOrNull(SrcReg); + if (!SrcBank) { + llvm_unreachable("Selecting copy with src reg with no bank?"); + } + + switch (SrcBank->getID()) { + case WebAssembly::I32RegBankID: + SrcRC = &WebAssembly::I32RegClass; + break; + case WebAssembly::I64RegBankID: + SrcRC = &WebAssembly::I64RegClass; + break; + case WebAssembly::F32RegBankID: + SrcRC = &WebAssembly::F32RegClass; + break; + case WebAssembly::F64RegBankID: + SrcRC = &WebAssembly::F64RegClass; + break; + default: + llvm_unreachable("Unknown reg bank to reg class mapping?"); + } + if (!RBI.constrainGenericRegister(SrcReg, *SrcRC, MRI)) { + LLVM_DEBUG(dbgs() << "Failed to constrain " << TII.getName(I.getOpcode()) + << " operand\n"); + return false; + } + } + + assert(TRI.getRegSizeInBits(*DstRC) == TRI.getRegSizeInBits(*SrcRC) && + "Copy between mismatching register sizes?"); + + if (DstRC != SrcRC) { + if (DstRC == &WebAssembly::I32RegClass && + SrcRC == &WebAssembly::F32RegClass) { + I.setDesc(TII.get(WebAssembly::I32_REINTERPRET_F32)); + return true; + } + + if (DstRC == &WebAssembly::F32RegClass && + SrcRC == &WebAssembly::I32RegClass) { + I.setDesc(TII.get(WebAssembly::F32_REINTERPRET_I32)); + return true; + } + + if (DstRC == &WebAssembly::I64RegClass && + SrcRC == &WebAssembly::F64RegClass) { + I.setDesc(TII.get(WebAssembly::I64_REINTERPRET_F64)); + return true; + } + + if (DstRC == &WebAssembly::F64RegClass && + SrcRC == &WebAssembly::I64RegClass) { + I.setDesc(TII.get(WebAssembly::F64_REINTERPRET_I64)); + return true; + } + + llvm_unreachable("Found bitcast/copy edge case."); + } + + return true; +} + +bool WebAssemblyInstructionSelector::select(MachineInstr &I) { + MachineBasicBlock &MBB = *I.getParent(); + MachineFunction &MF = *MBB.getParent(); + 
MachineRegisterInfo &MRI = MF.getRegInfo(); + const TargetLowering &TLI = *STI.getTargetLowering(); + + if (!isPreISelGenericOpcode(I.getOpcode())) { + if (I.isCopy()) + return selectCopy(I, MRI); + + return true; + } + + if (selectImpl(I, *CoverageInfo)) + return true; + + using namespace TargetOpcode; + + auto PointerWidth = MF.getDataLayout().getPointerSizeInBits(); + auto PtrIsI64 = PointerWidth == 64; + + switch (I.getOpcode()) { + case G_CONSTANT: { + assert(MRI.getType(I.getOperand(0).getReg()).isPointer() && + "G_CONSTANT selection fell-through with non-pointer?"); + + auto OrigImm = I.getOperand(1).getCImm()->getValue(); + + auto MaskedVal = OrigImm.getLoBits(PointerWidth); + assert(MaskedVal.eq(OrigImm) && + "Pointer immediate uses more bits than allowed"); + + I.setDesc(TII.get(PointerWidth == 64 ? WebAssembly::CONST_I64 + : WebAssembly::CONST_I32)); + I.removeOperand(1); + I.addOperand(MachineOperand::CreateImm(MaskedVal.getZExtValue())); + assert(constrainSelectedInstRegOperands(I, TII, TRI, RBI) && + "Couldn't constrain registers for instruction"); + return true; + } + case G_PTR_ADD: { + assert(MRI.getType(I.getOperand(0).getReg()).isPointer() && + "G_PTR_ADD selection fell-through with non-pointer?"); + + auto PointerWidth = MF.getDataLayout().getPointerSizeInBits(); + I.setDesc(TII.get(PointerWidth == 64 ? WebAssembly::ADD_I64 + : WebAssembly::ADD_I32)); + assert(constrainSelectedInstRegOperands(I, TII, TRI, RBI) && + "Couldn't constrain registers for instruction"); + + return true; + } + case G_PTRTOINT: { + assert(MRI.getType(I.getOperand(1).getReg()).isPointer() && + "G_PTRTOINT selection fell-through with non-pointer?"); + + auto PointerWidth = MF.getDataLayout().getPointerSizeInBits(); + I.setDesc(TII.get(PointerWidth == 64 ? 
WebAssembly::COPY_I64 + : WebAssembly::COPY_I32)); + assert(constrainSelectedInstRegOperands(I, TII, TRI, RBI) && + "Couldn't constrain registers for instruction"); + + return true; + } + case G_ICMP: { + Register LHS = I.getOperand(2).getReg(); + Register RHS = I.getOperand(3).getReg(); + CmpInst::Predicate Cond = + static_cast(I.getOperand(1).getPredicate()); + + auto CmpWidth = MRI.getType(LHS).getSizeInBits(); + assert(CmpWidth == MRI.getType(RHS).getSizeInBits() && + "LHS and RHS for ICMP are diffrent lengths???"); + + auto IsI64 = CmpWidth == 64; + + unsigned int CmpOpcode; + switch (Cond) { + case CmpInst::ICMP_EQ: + CmpOpcode = IsI64 ? WebAssembly::EQ_I64 : WebAssembly::EQ_I32; + break; + case CmpInst::ICMP_NE: + CmpOpcode = IsI64 ? WebAssembly::NE_I64 : WebAssembly::NE_I32; + break; + case CmpInst::ICMP_UGT: + CmpOpcode = IsI64 ? WebAssembly::GT_U_I64 : WebAssembly::GT_U_I32; + break; + case CmpInst::ICMP_UGE: + CmpOpcode = IsI64 ? WebAssembly::GE_U_I64 : WebAssembly::GE_U_I32; + break; + case CmpInst::ICMP_ULT: + CmpOpcode = IsI64 ? WebAssembly::LT_U_I64 : WebAssembly::LT_U_I32; + break; + case CmpInst::ICMP_ULE: + CmpOpcode = IsI64 ? WebAssembly::LE_U_I64 : WebAssembly::LE_U_I32; + break; + case CmpInst::ICMP_SGT: + CmpOpcode = IsI64 ? WebAssembly::GT_S_I64 : WebAssembly::GT_S_I32; + break; + case CmpInst::ICMP_SGE: + CmpOpcode = IsI64 ? WebAssembly::GE_S_I64 : WebAssembly::GE_S_I32; + break; + case CmpInst::ICMP_SLT: + CmpOpcode = IsI64 ? WebAssembly::LT_S_I64 : WebAssembly::LT_S_I32; + break; + case CmpInst::ICMP_SLE: + CmpOpcode = IsI64 ? WebAssembly::LE_S_I64 : WebAssembly::LE_S_I32; + break; + default: + llvm_unreachable("Unknown ICMP predicate"); + } + + I.setDesc(TII.get(CmpOpcode)); + I.removeOperand(1); + assert(constrainSelectedInstRegOperands(I, TII, TRI, RBI) && + "Couldn't constrain registers for instruction"); + + return true; + } + case G_FRAME_INDEX: { + MachineIRBuilder B(I); + + auto MIB = B.buildInstr(PtrIsI64 ? 
WebAssembly::ADD_I64 : WebAssembly::ADD_I32) + .addDef(I.getOperand(0).getReg()) + .addReg(PtrIsI64 ? WebAssembly::SP64 : WebAssembly::SP32) + .addFrameIndex(I.getOperand(1).getIndex()); + assert(constrainSelectedInstRegOperands(*MIB, TII, TRI, RBI) && + "Couldn't constrain registers for instruction"); + + I.eraseFromParent(); + return true; + } + case G_GLOBAL_VALUE: + assert(I.getOperand(1).getTargetFlags() == 0 && + "Unexpected target flags on generic G_GLOBAL_VALUE instruction"); + assert(WebAssembly::isValidAddressSpace( + MRI.getType(I.getOperand(0).getReg()).getAddressSpace()) && + "Invalid address space for WebAssembly target"); + + unsigned OperandFlags = 0; + const llvm::GlobalValue *GV = I.getOperand(1).getGlobal(); + // Since WebAssembly tables cannot yet be shared across modules, we don't + // need special treatment for tables in PIC mode. + if (TLI.isPositionIndependent() && + !WebAssembly::isWebAssemblyTableType(GV->getValueType())) { + if (TM.shouldAssumeDSOLocal(GV)) { + const char *BaseName; + if (GV->getValueType()->isFunctionTy()) { + BaseName = MF.createExternalSymbolName("__table_base"); + OperandFlags = WebAssemblyII::MO_TABLE_BASE_REL; + } else { + BaseName = MF.createExternalSymbolName("__memory_base"); + OperandFlags = WebAssemblyII::MO_MEMORY_BASE_REL; + } + MachineIRBuilder B(I); + + auto MemBase = MRI.createGenericVirtualRegister(LLT::pointer(0, PointerWidth)); + MRI.setRegClass(MemBase, PtrIsI64 ? &WebAssembly::I64RegClass : &WebAssembly::I32RegClass); + auto Offset = MRI.createGenericVirtualRegister(LLT::pointer(0, PointerWidth)); + MRI.setRegClass(Offset, PtrIsI64 ? &WebAssembly::I64RegClass : &WebAssembly::I32RegClass); + + B.buildInstr(PtrIsI64 ? WebAssembly::GLOBAL_GET_I64 + : WebAssembly::GLOBAL_GET_I32) + .addDef(MemBase) + .addExternalSymbol(BaseName); + + B.buildInstr(PtrIsI64 ? 
WebAssembly::CONST_I64 : WebAssembly::CONST_I32) + .addDef(Offset) + .addGlobalAddress(GV, I.getOperand(1).getOffset(), OperandFlags); + + auto MIB = + B.buildInstr(PtrIsI64 ? WebAssembly::ADD_I64 : WebAssembly::ADD_I32) + .addDef(I.getOperand(0).getReg()) + .addReg(MemBase) + .addReg(Offset); + assert(constrainSelectedInstRegOperands(*MIB, TII, TRI, RBI) && + "Couldn't constrain registers for instruction"); + + I.eraseFromParent(); + return true; + } + OperandFlags = WebAssemblyII::MO_GOT; + } + + auto NewOpc = MF.getDataLayout().getPointerSizeInBits() == 64 + ? WebAssembly::CONST_I64 + : WebAssembly::CONST_I32; + + if (OperandFlags & WebAssemblyII::MO_GOT) { + NewOpc = MF.getDataLayout().getPointerSizeInBits() == 64 + ? WebAssembly::GLOBAL_GET_I64 + : WebAssembly::GLOBAL_GET_I32; + } + + I.setDesc(TII.get(NewOpc)); + I.getOperand(1).setTargetFlags(OperandFlags); + assert(constrainSelectedInstRegOperands(I, TII, TRI, RBI) && + "Couldn't constrain registers for instruction"); + + return true; + } + + return false; +} + +namespace llvm { +InstructionSelector * +createWebAssemblyInstructionSelector(const WebAssemblyTargetMachine &TM, + const WebAssemblySubtarget &Subtarget, + const WebAssemblyRegisterBankInfo &RBI) { + return new WebAssemblyInstructionSelector(TM, Subtarget, RBI); +} +} // namespace llvm diff --git a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyInstructionSelector.h b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyInstructionSelector.h deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/llvm/lib/Target/WebAssembly/WebAssembly.h b/llvm/lib/Target/WebAssembly/WebAssembly.h index 2dbd597f01cc9..0c56c5077c563 100644 --- a/llvm/lib/Target/WebAssembly/WebAssembly.h +++ b/llvm/lib/Target/WebAssembly/WebAssembly.h @@ -15,6 +15,9 @@ #ifndef LLVM_LIB_TARGET_WEBASSEMBLY_WEBASSEMBLY_H #define LLVM_LIB_TARGET_WEBASSEMBLY_WEBASSEMBLY_H +#include "GISel/WebAssemblyRegisterBankInfo.h" +#include "WebAssemblySubtarget.h" +#include 
"llvm/CodeGen/GlobalISel/InstructionSelector.h" #include "llvm/PassRegistry.h" #include "llvm/Support/CodeGen.h" @@ -32,6 +35,12 @@ FunctionPass *createWebAssemblyOptimizeReturned(); FunctionPass *createWebAssemblyLowerRefTypesIntPtrConv(); FunctionPass *createWebAssemblyRefTypeMem2Local(); +// GlobalISel +InstructionSelector * +createWebAssemblyInstructionSelector(const WebAssemblyTargetMachine &, + const WebAssemblySubtarget &, + const WebAssemblyRegisterBankInfo &); + // ISel and immediate followup passes. FunctionPass *createWebAssemblyISelDag(WebAssemblyTargetMachine &TM, CodeGenOptLevel OptLevel); diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyInstrMemory.td b/llvm/lib/Target/WebAssembly/WebAssemblyInstrMemory.td index 0cbe9d0c6a6a4..5d5af1b7a61bd 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyInstrMemory.td +++ b/llvm/lib/Target/WebAssembly/WebAssemblyInstrMemory.td @@ -33,6 +33,12 @@ def AddrOps32 : ComplexPattern; def AddrOps64 : ComplexPattern; +def gi_AddrOps32 : GIComplexOperandMatcher, + GIComplexPatternEquiv; + +def gi_AddrOps64 : GIComplexOperandMatcher, + GIComplexPatternEquiv; + // Defines atomic and non-atomic loads, regular and extending. 
multiclass WebAssemblyLoad reqs = []> { diff --git a/llvm/lib/Target/WebAssembly/WebAssemblySubtarget.cpp b/llvm/lib/Target/WebAssembly/WebAssemblySubtarget.cpp index b99c35acabef6..315cbb65371a0 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblySubtarget.cpp +++ b/llvm/lib/Target/WebAssembly/WebAssemblySubtarget.cpp @@ -12,6 +12,7 @@ /// //===----------------------------------------------------------------------===// +#include "WebAssembly.h" #include "WebAssemblySubtarget.h" #include "GISel/WebAssemblyCallLowering.h" #include "GISel/WebAssemblyLegalizerInfo.h" @@ -75,9 +76,9 @@ WebAssemblySubtarget::WebAssemblySubtarget(const Triple &TT, Legalizer.reset(new WebAssemblyLegalizerInfo(*this)); auto *RBI = new WebAssemblyRegisterBankInfo(*getRegisterInfo()); RegBankInfo.reset(RBI); -/* + InstSelector.reset(createWebAssemblyInstructionSelector( - *static_cast(&TM), *this, *RBI));*/ + *static_cast(&TM), *this, *RBI)); } bool WebAssemblySubtarget::enableAtomicExpand() const { From baae9d5ff1a896b60356d7912ef86af29c98c3f7 Mon Sep 17 00:00:00 2001 From: Demetrius Kanios Date: Fri, 3 Oct 2025 10:41:03 -0700 Subject: [PATCH 11/17] Instruction selection wave 2 --- llvm/lib/Target/WebAssembly/CMakeLists.txt | 4 +- .../GISel/WebAssemblyCallLowering.cpp | 15 +- .../GISel/WebAssemblyInstructionSelector.cpp | 273 +++++++++++------- .../GISel/WebAssemblyLegalizerInfo.cpp | 147 +++++++++- .../GISel/WebAssemblyRegisterBankInfo.cpp | 5 +- .../Target/WebAssembly/WebAssemblyGISel.td | 133 +++++++++ .../WebAssembly/WebAssemblyInstrMemory.td | 6 - .../WebAssembly/WebAssemblyTargetMachine.cpp | 6 + 8 files changed, 470 insertions(+), 119 deletions(-) create mode 100644 llvm/lib/Target/WebAssembly/WebAssemblyGISel.td diff --git a/llvm/lib/Target/WebAssembly/CMakeLists.txt b/llvm/lib/Target/WebAssembly/CMakeLists.txt index e80850eb073fb..a5a23d39cf700 100644 --- a/llvm/lib/Target/WebAssembly/CMakeLists.txt +++ b/llvm/lib/Target/WebAssembly/CMakeLists.txt @@ -7,7 +7,6 @@ tablegen(LLVM 
WebAssemblyGenAsmWriter.inc -gen-asm-writer) tablegen(LLVM WebAssemblyGenDAGISel.inc -gen-dag-isel) tablegen(LLVM WebAssemblyGenDisassemblerTables.inc -gen-disassembler) tablegen(LLVM WebAssemblyGenFastISel.inc -gen-fast-isel) -tablegen(LLVM WebAssemblyGenGlobalISel.inc -gen-global-isel) tablegen(LLVM WebAssemblyGenInstrInfo.inc -gen-instr-info) tablegen(LLVM WebAssemblyGenMCCodeEmitter.inc -gen-emitter) tablegen(LLVM WebAssemblyGenRegisterBank.inc -gen-register-bank) @@ -15,6 +14,9 @@ tablegen(LLVM WebAssemblyGenRegisterInfo.inc -gen-register-info) tablegen(LLVM WebAssemblyGenSDNodeInfo.inc -gen-sd-node-info) tablegen(LLVM WebAssemblyGenSubtargetInfo.inc -gen-subtarget) +set(LLVM_TARGET_DEFINITIONS WebAssemblyGISel.td) +tablegen(LLVM WebAssemblyGenGlobalISel.inc -gen-global-isel) + add_public_tablegen_target(WebAssemblyCommonTableGen) add_llvm_target(WebAssemblyCodeGen diff --git a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.cpp b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.cpp index 733d676ac988a..1990d3e4e3adb 100644 --- a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.cpp +++ b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.cpp @@ -527,7 +527,8 @@ bool WebAssemblyCallLowering::lowerReturn(MachineIRBuilder &MIRBuilder, auto NewOutReg = Arg.Regs[Part]; if (!RBI.constrainGenericRegister(NewOutReg, NewRegClass, MRI)) { NewOutReg = MRI.createGenericVirtualRegister(NewLLT); - assert(RBI.constrainGenericRegister(NewOutReg, NewRegClass, MRI) && "Couldn't constrain brand-new register?"); + assert(RBI.constrainGenericRegister(NewOutReg, NewRegClass, MRI) && + "Couldn't constrain brand-new register?"); MIRBuilder.buildCopy(NewOutReg, Arg.Regs[Part]); } MIB.addUse(NewOutReg); @@ -704,9 +705,12 @@ bool WebAssemblyCallLowering::lowerFormalArguments( getLLTForType(*PointerType::get(Ctx, 0), DL)); MFI->setVarargBufferVreg(VarargVreg); - MIRBuilder.buildInstr(getWASMArgOpcode(PtrVT)) - .addDef(VarargVreg) - .addImm(FinalArgIdx); 
+ auto ArgInst = MIRBuilder.buildInstr(getWASMArgOpcode(PtrVT)) + .addDef(VarargVreg) + .addImm(FinalArgIdx); + + constrainOperandRegClass(MF, TRI, MRI, TII, RBI, *ArgInst, + ArgInst->getDesc(), ArgInst->getOperand(0), 0); MFI->addParam(PtrVT); ++FinalArgIdx; @@ -911,7 +915,8 @@ bool WebAssemblyCallLowering::lowerCall(MachineIRBuilder &MIRBuilder, auto NewRetReg = Ret.Regs[Part]; if (!RBI.constrainGenericRegister(NewRetReg, NewRegClass, MRI)) { NewRetReg = MRI.createGenericVirtualRegister(NewLLT); - assert(RBI.constrainGenericRegister(NewRetReg, NewRegClass, MRI) && "Couldn't constrain brand-new register?"); + assert(RBI.constrainGenericRegister(NewRetReg, NewRegClass, MRI) && + "Couldn't constrain brand-new register?"); MIRBuilder.buildCopy(NewRetReg, Ret.Regs[Part]); } CallInst.addDef(Ret.Regs[Part]); diff --git a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyInstructionSelector.cpp b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyInstructionSelector.cpp index aea7b9a424a62..0ef5f357718ac 100644 --- a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyInstructionSelector.cpp +++ b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyInstructionSelector.cpp @@ -23,10 +23,12 @@ #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h" #include "llvm/CodeGen/GlobalISel/Utils.h" #include "llvm/CodeGen/MachineFunction.h" +#include "llvm/CodeGen/MachineJumpTableInfo.h" #include "llvm/CodeGen/MachineOperand.h" #include "llvm/CodeGen/RegisterBank.h" #include "llvm/CodeGen/TargetLowering.h" #include "llvm/CodeGen/TargetOpcodes.h" +#include "llvm/IR/Constants.h" #include "llvm/IR/IntrinsicsWebAssembly.h" #include "llvm/MC/TargetRegistry.h" #include "llvm/Support/ErrorHandling.h" @@ -126,10 +128,10 @@ WebAssemblyInstructionSelector::selectAddrOperands(LLT AddrType, // and offset for an add that needs wrapping. if (RootDef.getFlag(MachineInstr::MIFlag::NoUWrap)) { for (size_t i = 0; i < 2; ++i) { - //MachineOperand &Op = i == 0 ? LHS : RHS; + // MachineOperand &Op = i == 0 ? 
LHS : RHS; MachineInstr &OpDef = i == 0 ? LHSDef : RHSDef; MachineOperand &OtherOp = i == 0 ? RHS : LHS; - //MachineInstr &OtherOpDef = i == 0 ? RHSDef : LHSDef; + // MachineInstr &OtherOpDef = i == 0 ? RHSDef : LHSDef; if (OpDef.getOpcode() == TargetOpcode::G_CONSTANT) { auto Offset = OpDef.getOperand(1).getCImm()->getZExtValue(); @@ -172,6 +174,23 @@ WebAssemblyInstructionSelector::selectAddrOperands(LLT AddrType, }}; } + if (!TM.isPositionIndependent() && + RootDef.getOpcode() == TargetOpcode::G_GLOBAL_VALUE) { + auto *Offset = RootDef.getOperand(1).getGlobal(); + auto Addr = MRI.createGenericVirtualRegister(AddrType); + + MachineIRBuilder B(RootDef); + + auto MIB = B.buildInstr(ConstOpc).addDef(Addr).addImm(0); + assert(constrainSelectedInstRegOperands(*MIB, TII, TRI, RBI) && + "Couldn't constrain registers for instruction"); + + return {{ + [=](MachineInstrBuilder &MIB) { MIB.addGlobalAddress(Offset); }, + [=](MachineInstrBuilder &MIB) { MIB.addReg(Addr); }, + }}; + } + return {{ [=](MachineInstrBuilder &MIB) { MIB.addImm(0); }, [=](MachineInstrBuilder &MIB) { MIB.add(Root); }, @@ -195,29 +214,22 @@ bool WebAssemblyInstructionSelector::selectCopy( Register DstReg = I.getOperand(0).getReg(); Register SrcReg = I.getOperand(1).getReg(); + const TargetRegisterClass *DstRC; if (DstReg.isPhysical()) { - if (DstReg.id() == WebAssembly::SP32) { - if (!RBI.constrainGenericRegister(DstReg, WebAssembly::I32RegClass, - MRI)) { - LLVM_DEBUG(dbgs() << "Failed to constrain " - << TII.getName(I.getOpcode()) << " operand\n"); - return false; - } - return true; - } - if (DstReg.id() == WebAssembly::SP64) { - if (!RBI.constrainGenericRegister(DstReg, WebAssembly::I64RegClass, - MRI)) { - LLVM_DEBUG(dbgs() << "Failed to constrain " - << TII.getName(I.getOpcode()) << " operand\n"); - return false; - } - return true; + switch (DstReg.id()) { + case WebAssembly::SP32: + DstRC = &WebAssembly::I32RegClass; + break; + case WebAssembly::SP64: + DstRC = &WebAssembly::I64RegClass; + 
break; + default: + llvm_unreachable("Copy to physical register other than SP32 or SP64?"); } - llvm_unreachable("Copy to physical register other than SP32 or SP64?"); + } else { + DstRC = MRI.getRegClassOrNull(DstReg); } - const TargetRegisterClass *DstRC = MRI.getRegClassOrNull(DstReg); if (!DstRC) { const RegisterBank *DstBank = MRI.getRegBankOrNull(DstReg); if (!DstBank) { @@ -240,14 +252,29 @@ bool WebAssemblyInstructionSelector::selectCopy( default: llvm_unreachable("Unknown reg bank to reg class mapping?"); } - if (!RBI.constrainGenericRegister(DstReg, *DstRC, MRI)) { + if (!constrainOperandRegClass(*MF, TRI, MRI, TII, RBI, I, *DstRC, + I.getOperand(0))) { LLVM_DEBUG(dbgs() << "Failed to constrain " << TII.getName(I.getOpcode()) << " operand\n"); return false; } } - const TargetRegisterClass *SrcRC = MRI.getRegClassOrNull(SrcReg); + const TargetRegisterClass *SrcRC; + if (SrcReg.isPhysical()) { + switch (SrcReg.id()) { + case WebAssembly::SP32: + SrcRC = &WebAssembly::I32RegClass; + break; + case WebAssembly::SP64: + SrcRC = &WebAssembly::I64RegClass; + break; + default: + llvm_unreachable("Copy to physical register other than SP32 or SP64?"); + } + } else { + SrcRC = MRI.getRegClassOrNull(SrcReg); + } if (!SrcRC) { const RegisterBank *SrcBank = MRI.getRegBankOrNull(SrcReg); if (!SrcBank) { @@ -270,7 +297,8 @@ bool WebAssemblyInstructionSelector::selectCopy( default: llvm_unreachable("Unknown reg bank to reg class mapping?"); } - if (!RBI.constrainGenericRegister(SrcReg, *SrcRC, MRI)) { + if (!constrainOperandRegClass(*MF, TRI, MRI, TII, RBI, I, *SrcRC, + I.getOperand(1))) { LLVM_DEBUG(dbgs() << "Failed to constrain " << TII.getName(I.getOpcode()) << " operand\n"); return false; @@ -311,12 +339,67 @@ bool WebAssemblyInstructionSelector::selectCopy( return true; } +static const TargetRegisterClass * +getRegClassForTypeOnBank(const RegisterBank &RB) { + switch (RB.getID()) { + case WebAssembly::I32RegBankID: + return &WebAssembly::I32RegClass; + case 
WebAssembly::I64RegBankID: + return &WebAssembly::I64RegClass; + case WebAssembly::F32RegBankID: + return &WebAssembly::F32RegClass; + case WebAssembly::F64RegBankID: + return &WebAssembly::F64RegClass; + case WebAssembly::EXNREFRegBankID: + return &WebAssembly::EXNREFRegClass; + case WebAssembly::EXTERNREFRegBankID: + return &WebAssembly::EXTERNREFRegClass; + case WebAssembly::FUNCREFRegBankID: + return &WebAssembly::FUNCREFRegClass; + // case WebAssembly::V128RegBankID: + // return &WebAssembly::V128RegClass; + } + + return nullptr; +} + bool WebAssemblyInstructionSelector::select(MachineInstr &I) { MachineBasicBlock &MBB = *I.getParent(); MachineFunction &MF = *MBB.getParent(); MachineRegisterInfo &MRI = MF.getRegInfo(); const TargetLowering &TLI = *STI.getTargetLowering(); + if (!I.isPreISelOpcode() || I.getOpcode() == TargetOpcode::G_PHI) { + if (I.getOpcode() == TargetOpcode::PHI || + I.getOpcode() == TargetOpcode::G_PHI) { + const Register DefReg = I.getOperand(0).getReg(); + const LLT DefTy = MRI.getType(DefReg); + + const RegClassOrRegBank &RegClassOrBank = + MRI.getRegClassOrRegBank(DefReg); + + const TargetRegisterClass *DefRC = + dyn_cast(RegClassOrBank); + + if (!DefRC) { + if (!DefTy.isValid()) { + LLVM_DEBUG(dbgs() << "PHI operand has no type, not a gvreg?\n"); + return false; + } + const RegisterBank &RB = *cast(RegClassOrBank); + DefRC = getRegClassForTypeOnBank(RB); + if (!DefRC) { + LLVM_DEBUG(dbgs() << "PHI operand has unexpected size/bank\n"); + return false; + } + } + + I.setDesc(TII.get(TargetOpcode::PHI)); + + return RBI.constrainGenericRegister(DefReg, *DefRC, MRI) != nullptr; + } + } + if (!isPreISelGenericOpcode(I.getOpcode())) { if (I.isCopy()) return selectCopy(I, MRI); @@ -333,22 +416,61 @@ bool WebAssemblyInstructionSelector::select(MachineInstr &I) { auto PtrIsI64 = PointerWidth == 64; switch (I.getOpcode()) { - case G_CONSTANT: { - assert(MRI.getType(I.getOperand(0).getReg()).isPointer() && - "G_CONSTANT selection fell-through with 
non-pointer?"); + case G_IMPLICIT_DEF: { + const Register DefReg = I.getOperand(0).getReg(); + const LLT DefTy = MRI.getType(DefReg); - auto OrigImm = I.getOperand(1).getCImm()->getValue(); + const RegClassOrRegBank &RegClassOrBank = MRI.getRegClassOrRegBank(DefReg); - auto MaskedVal = OrigImm.getLoBits(PointerWidth); - assert(MaskedVal.eq(OrigImm) && - "Pointer immediate uses more bits than allowed"); + const TargetRegisterClass *DefRC = + dyn_cast(RegClassOrBank); - I.setDesc(TII.get(PointerWidth == 64 ? WebAssembly::CONST_I64 - : WebAssembly::CONST_I32)); - I.removeOperand(1); - I.addOperand(MachineOperand::CreateImm(MaskedVal.getZExtValue())); - assert(constrainSelectedInstRegOperands(I, TII, TRI, RBI) && + if (!DefRC) { + if (!DefTy.isValid()) { + LLVM_DEBUG( + dbgs() << "IMPLICIT_DEF operand has no type, not a gvreg?\n"); + return false; + } + const RegisterBank &RB = *cast(RegClassOrBank); + DefRC = getRegClassForTypeOnBank(RB); + if (!DefRC) { + LLVM_DEBUG(dbgs() << "IMPLICIT_DEF operand has unexpected size/bank\n"); + return false; + } + } + + I.setDesc(TII.get(TargetOpcode::IMPLICIT_DEF)); + + return RBI.constrainGenericRegister(DefReg, *DefRC, MRI) != nullptr; + return true; + } + case G_BRJT: { + auto JT = I.getOperand(1); + auto Index = I.getOperand(2); + + assert(JT.getTargetFlags() == 0 && "WebAssembly doesn't set target flags"); + + MachineIRBuilder B(I); + + MachineJumpTableInfo *MJTI = MF.getJumpTableInfo(); + const auto &MBBs = MJTI->getJumpTables()[JT.getIndex()].MBBs; + + auto MIB = B.buildInstr(PtrIsI64 ? WebAssembly::BR_TABLE_I64 + : WebAssembly::BR_TABLE_I32) + .add(Index); + + for (auto *MBB : MBBs) + MIB.addMBB(MBB); + + // Add the first MBB as a dummy default target for now. This will be + // replaced with the proper default target (and the preceding range check + // eliminated) if possible by WebAssemblyFixBrTableDefaults. 
+ MIB.addMBB(*MBBs.begin()); + + assert(constrainSelectedInstRegOperands(*MIB, TII, TRI, RBI) && "Couldn't constrain registers for instruction"); + + I.eraseFromParent(); return true; } case G_PTR_ADD: { @@ -367,7 +489,6 @@ bool WebAssemblyInstructionSelector::select(MachineInstr &I) { assert(MRI.getType(I.getOperand(1).getReg()).isPointer() && "G_PTRTOINT selection fell-through with non-pointer?"); - auto PointerWidth = MF.getDataLayout().getPointerSizeInBits(); I.setDesc(TII.get(PointerWidth == 64 ? WebAssembly::COPY_I64 : WebAssembly::COPY_I32)); assert(constrainSelectedInstRegOperands(I, TII, TRI, RBI) && @@ -375,56 +496,12 @@ bool WebAssemblyInstructionSelector::select(MachineInstr &I) { return true; } - case G_ICMP: { - Register LHS = I.getOperand(2).getReg(); - Register RHS = I.getOperand(3).getReg(); - CmpInst::Predicate Cond = - static_cast(I.getOperand(1).getPredicate()); - - auto CmpWidth = MRI.getType(LHS).getSizeInBits(); - assert(CmpWidth == MRI.getType(RHS).getSizeInBits() && - "LHS and RHS for ICMP are diffrent lengths???"); - - auto IsI64 = CmpWidth == 64; - - unsigned int CmpOpcode; - switch (Cond) { - case CmpInst::ICMP_EQ: - CmpOpcode = IsI64 ? WebAssembly::EQ_I64 : WebAssembly::EQ_I32; - break; - case CmpInst::ICMP_NE: - CmpOpcode = IsI64 ? WebAssembly::NE_I64 : WebAssembly::NE_I32; - break; - case CmpInst::ICMP_UGT: - CmpOpcode = IsI64 ? WebAssembly::GT_U_I64 : WebAssembly::GT_U_I32; - break; - case CmpInst::ICMP_UGE: - CmpOpcode = IsI64 ? WebAssembly::GE_U_I64 : WebAssembly::GE_U_I32; - break; - case CmpInst::ICMP_ULT: - CmpOpcode = IsI64 ? WebAssembly::LT_U_I64 : WebAssembly::LT_U_I32; - break; - case CmpInst::ICMP_ULE: - CmpOpcode = IsI64 ? WebAssembly::LE_U_I64 : WebAssembly::LE_U_I32; - break; - case CmpInst::ICMP_SGT: - CmpOpcode = IsI64 ? WebAssembly::GT_S_I64 : WebAssembly::GT_S_I32; - break; - case CmpInst::ICMP_SGE: - CmpOpcode = IsI64 ? 
WebAssembly::GE_S_I64 : WebAssembly::GE_S_I32; - break; - case CmpInst::ICMP_SLT: - CmpOpcode = IsI64 ? WebAssembly::LT_S_I64 : WebAssembly::LT_S_I32; - break; - case CmpInst::ICMP_SLE: - CmpOpcode = IsI64 ? WebAssembly::LE_S_I64 : WebAssembly::LE_S_I32; - break; - default: - llvm_unreachable("Unknown ICMP predicate"); - } + case G_INTTOPTR: { + assert(MRI.getType(I.getOperand(0).getReg()).isPointer() && + "G_INTTOPTR selection fell-through with non-pointer?"); - I.setDesc(TII.get(CmpOpcode)); - I.removeOperand(1); + I.setDesc(TII.get(PointerWidth == 64 ? WebAssembly::COPY_I64 + : WebAssembly::COPY_I32)); assert(constrainSelectedInstRegOperands(I, TII, TRI, RBI) && "Couldn't constrain registers for instruction"); @@ -433,14 +510,10 @@ bool WebAssemblyInstructionSelector::select(MachineInstr &I) { case G_FRAME_INDEX: { MachineIRBuilder B(I); - auto MIB = B.buildInstr(PtrIsI64 ? WebAssembly::ADD_I64 : WebAssembly::ADD_I32) - .addDef(I.getOperand(0).getReg()) - .addReg(PtrIsI64 ? WebAssembly::SP64 : WebAssembly::SP32) - .addFrameIndex(I.getOperand(1).getIndex()); - assert(constrainSelectedInstRegOperands(*MIB, TII, TRI, RBI) && + I.setDesc( + TII.get(PtrIsI64 ? WebAssembly::COPY_I64 : WebAssembly::COPY_I32)); + assert(constrainSelectedInstRegOperands(I, TII, TRI, RBI) && "Couldn't constrain registers for instruction"); - - I.eraseFromParent(); return true; } case G_GLOBAL_VALUE: @@ -467,13 +540,17 @@ bool WebAssemblyInstructionSelector::select(MachineInstr &I) { } MachineIRBuilder B(I); - auto MemBase = MRI.createGenericVirtualRegister(LLT::pointer(0, PointerWidth)); - MRI.setRegClass(MemBase, PtrIsI64 ? &WebAssembly::I64RegClass : &WebAssembly::I32RegClass); - auto Offset = MRI.createGenericVirtualRegister(LLT::pointer(0, PointerWidth)); - MRI.setRegClass(Offset, PtrIsI64 ? &WebAssembly::I64RegClass : &WebAssembly::I32RegClass); + auto MemBase = + MRI.createGenericVirtualRegister(LLT::pointer(0, PointerWidth)); + MRI.setRegClass(MemBase, PtrIsI64 ? 
&WebAssembly::I64RegClass + : &WebAssembly::I32RegClass); + auto Offset = + MRI.createGenericVirtualRegister(LLT::pointer(0, PointerWidth)); + MRI.setRegClass(Offset, PtrIsI64 ? &WebAssembly::I64RegClass + : &WebAssembly::I32RegClass); B.buildInstr(PtrIsI64 ? WebAssembly::GLOBAL_GET_I64 - : WebAssembly::GLOBAL_GET_I32) + : WebAssembly::GLOBAL_GET_I32) .addDef(MemBase) .addExternalSymbol(BaseName); diff --git a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyLegalizerInfo.cpp b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyLegalizerInfo.cpp index ae2ac0a512427..633dd48cb3ac6 100644 --- a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyLegalizerInfo.cpp +++ b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyLegalizerInfo.cpp @@ -43,9 +43,7 @@ WebAssemblyLegalizerInfo::WebAssemblyLegalizerInfo( getActionDefinitionsBuilder(G_BR).alwaysLegal(); getActionDefinitionsBuilder(G_BRCOND).legalFor({s32}).clampScalar(0, s32, s32); - getActionDefinitionsBuilder(G_BRJT) - .legalFor({{p0, s32}}) - .clampScalar(1, s32, s32); + getActionDefinitionsBuilder(G_BRJT).legalFor({{p0, p0s}}); getActionDefinitionsBuilder(G_SELECT) .legalFor({{s32, s32}, {s64, s32}, {p0, s32}}) @@ -62,7 +60,7 @@ WebAssemblyLegalizerInfo::WebAssemblyLegalizerInfo( .clampScalar(0, s32, s32); getActionDefinitionsBuilder(G_FCMP) - .legalFor({{s32, s32}, {s32, s64}}) + .customFor({{s32, s32}, {s32, s64}}) .clampScalar(0, s32, s32) .libcall(); @@ -150,8 +148,12 @@ WebAssemblyLegalizerInfo::WebAssemblyLegalizerInfo( .widenScalarToNextPow2(1) .clampScalar(1, s32, s64); - getActionDefinitionsBuilder(G_PTRTOINT).legalFor({{p0s, p0}}); - getActionDefinitionsBuilder(G_INTTOPTR).legalFor({{p0, p0s}}); + getActionDefinitionsBuilder(G_PTRTOINT) + .legalFor({p0s, p0}) + .customForCartesianProduct({s32, s64}, {p0}); + getActionDefinitionsBuilder(G_INTTOPTR) + .legalFor({p0, p0s}) + .customForCartesianProduct({p0}, {s32, s64}); getActionDefinitionsBuilder(G_PTR_ADD).legalFor({{p0, p0s}}); getActionDefinitionsBuilder(G_LOAD) @@ -260,6 
+262,139 @@ bool WebAssemblyLegalizerInfo::legalizeCustom( auto &MIRBuilder = Helper.MIRBuilder; switch (MI.getOpcode()) { + case WebAssembly::G_PTRTOINT: { + auto TmpReg = MRI.createGenericVirtualRegister( + LLT::scalar(MIRBuilder.getDataLayout().getPointerSizeInBits(0))); + + MIRBuilder.buildPtrToInt(TmpReg, MI.getOperand(1)); + MIRBuilder.buildAnyExtOrTrunc(MI.getOperand(0), TmpReg); + MI.eraseFromParent(); + return true; + } + case WebAssembly::G_INTTOPTR: { + auto TmpReg = MRI.createGenericVirtualRegister( + LLT::scalar(MIRBuilder.getDataLayout().getPointerSizeInBits(0))); + + MIRBuilder.buildAnyExtOrTrunc(TmpReg, MI.getOperand(1)); + MIRBuilder.buildIntToPtr(MI.getOperand(0), TmpReg); + MI.eraseFromParent(); + return true; + } + case TargetOpcode::G_FCMP: { + Register LHS = MI.getOperand(2).getReg(); + Register RHS = MI.getOperand(3).getReg(); + CmpInst::Predicate Cond = + static_cast(MI.getOperand(1).getPredicate()); + + auto CmpWidth = MRI.getType(LHS).getSizeInBits(); + assert(CmpWidth == MRI.getType(RHS).getSizeInBits() && + "LHS and RHS for FCMP are diffrent lengths???"); + + auto IsI64 = CmpWidth == 64; + + switch (Cond) { + case CmpInst::FCMP_FALSE: + return false; + case CmpInst::FCMP_OEQ: + return true; + case CmpInst::FCMP_OGT: + return true; + case CmpInst::FCMP_OGE: + return true; + case CmpInst::FCMP_OLT: + return true; + case CmpInst::FCMP_OLE: + return true; + case CmpInst::FCMP_ONE: { + auto TmpRegA = MRI.createGenericVirtualRegister(LLT::scalar(1)); + auto TmpRegB = MRI.createGenericVirtualRegister(LLT::scalar(1)); + auto TmpRegC = MRI.createGenericVirtualRegister(LLT::scalar(1)); + + MIRBuilder.buildFCmp(CmpInst::FCMP_OGT, TmpRegA, LHS, RHS); + MIRBuilder.buildFCmp(CmpInst::FCMP_OLT, TmpRegB, LHS, RHS); + MIRBuilder.buildOr(TmpRegC, TmpRegA, TmpRegB); + MIRBuilder.buildAnyExt(MI.getOperand(0).getReg(), TmpRegC); + break; + } + case CmpInst::FCMP_ORD: { + auto TmpRegA = MRI.createGenericVirtualRegister(LLT::scalar(1)); + auto TmpRegB = 
MRI.createGenericVirtualRegister(LLT::scalar(1)); + auto TmpRegC = MRI.createGenericVirtualRegister(LLT::scalar(1)); + + MIRBuilder.buildFCmp(CmpInst::FCMP_OEQ, TmpRegA, LHS, LHS); + MIRBuilder.buildFCmp(CmpInst::FCMP_OEQ, TmpRegB, RHS, RHS); + MIRBuilder.buildAnd(TmpRegC, TmpRegA, TmpRegB); + MIRBuilder.buildAnyExt(MI.getOperand(0).getReg(), TmpRegC); + break; + } + case CmpInst::FCMP_UNO: { + auto TmpRegA = MRI.createGenericVirtualRegister(LLT::scalar(1)); + auto TmpRegB = MRI.createGenericVirtualRegister(LLT::scalar(1)); + auto TmpRegC = MRI.createGenericVirtualRegister(LLT::scalar(1)); + + MIRBuilder.buildFCmp(CmpInst::FCMP_UNE, TmpRegA, LHS, LHS); + MIRBuilder.buildFCmp(CmpInst::FCMP_UNE, TmpRegB, RHS, RHS); + MIRBuilder.buildOr(TmpRegC, TmpRegA, TmpRegB); + MIRBuilder.buildAnyExt(MI.getOperand(0).getReg(), TmpRegC); + break; + } + case CmpInst::FCMP_UEQ: { + auto TmpRegA = MRI.createGenericVirtualRegister(LLT::scalar(1)); + auto TmpRegB = MRI.createGenericVirtualRegister(LLT::scalar(1)); + + MIRBuilder.buildFCmp(CmpInst::FCMP_ONE, TmpRegA, LHS, RHS); + MIRBuilder.buildNot(TmpRegB, TmpRegA); + MIRBuilder.buildAnyExt(MI.getOperand(0).getReg(), TmpRegB); + break; + } + case CmpInst::FCMP_UGT: { + auto TmpRegA = MRI.createGenericVirtualRegister(LLT::scalar(1)); + auto TmpRegB = MRI.createGenericVirtualRegister(LLT::scalar(1)); + + MIRBuilder.buildFCmp(CmpInst::FCMP_OLE, TmpRegA, LHS, RHS); + MIRBuilder.buildNot(TmpRegB, TmpRegA); + MIRBuilder.buildAnyExt(MI.getOperand(0).getReg(), TmpRegB); + break; + } + case CmpInst::FCMP_UGE: { + auto TmpRegA = MRI.createGenericVirtualRegister(LLT::scalar(1)); + auto TmpRegB = MRI.createGenericVirtualRegister(LLT::scalar(1)); + + MIRBuilder.buildFCmp(CmpInst::FCMP_OLT, TmpRegA, LHS, RHS); + MIRBuilder.buildNot(TmpRegB, TmpRegA); + MIRBuilder.buildAnyExt(MI.getOperand(0).getReg(), TmpRegB); + break; + } + case CmpInst::FCMP_ULT: { + auto TmpRegA = MRI.createGenericVirtualRegister(LLT::scalar(1)); + auto TmpRegB = 
MRI.createGenericVirtualRegister(LLT::scalar(1)); + + MIRBuilder.buildFCmp(CmpInst::FCMP_OGE, TmpRegA, LHS, RHS); + MIRBuilder.buildNot(TmpRegB, TmpRegA); + MIRBuilder.buildAnyExt(MI.getOperand(0).getReg(), TmpRegB); + break; + } + case CmpInst::FCMP_ULE: { + auto TmpRegA = MRI.createGenericVirtualRegister(LLT::scalar(1)); + auto TmpRegB = MRI.createGenericVirtualRegister(LLT::scalar(1)); + + MIRBuilder.buildFCmp(CmpInst::FCMP_OGT, TmpRegA, LHS, RHS); + MIRBuilder.buildNot(TmpRegB, TmpRegA); + MIRBuilder.buildAnyExt(MI.getOperand(0).getReg(), TmpRegB); + break; + } + case CmpInst::FCMP_UNE: + return true; + case CmpInst::FCMP_TRUE: + return false; + default: + llvm_unreachable("Unknown FCMP predicate"); + } + + MI.eraseFromParent(); + + return true; + } case TargetOpcode::G_SEXT_INREG: { assert(MI.getOperand(2).isImm() && "Expected immediate"); diff --git a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyRegisterBankInfo.cpp b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyRegisterBankInfo.cpp index fa4103a8b1b31..096cd2125ec22 100644 --- a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyRegisterBankInfo.cpp +++ b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyRegisterBankInfo.cpp @@ -304,9 +304,8 @@ WebAssemblyRegisterBankInfo::getInstrMapping(const MachineInstr &MI) const { OperandsMapping = getOperandsMapping({&Op0IntValueMapping, nullptr}); break; case G_BRJT: - OperandsMapping = - getOperandsMapping({&Op0IntValueMapping, nullptr, - &WebAssembly::ValueMappings[WebAssembly::I32Idx]}); + OperandsMapping = getOperandsMapping( + {&Op0IntValueMapping, nullptr, &Pointer0ValueMapping}); break; case COPY: { Register DstReg = MI.getOperand(0).getReg(); diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyGISel.td b/llvm/lib/Target/WebAssembly/WebAssemblyGISel.td new file mode 100644 index 0000000000000..331424759f0df --- /dev/null +++ b/llvm/lib/Target/WebAssembly/WebAssemblyGISel.td @@ -0,0 +1,133 @@ +//===-- WebAssemblyGIsel.td - WASM GlobalISel Patterns -----*- tablegen 
-*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +/// \file +/// This file contains patterns that are relevant to GlobalISel, including +/// GIComplexOperandMatcher definitions for equivalent SelectionDAG +/// ComplexPatterns. +// +//===----------------------------------------------------------------------===// + +include "WebAssembly.td" + + +//===----------------------------------------------------------------------===// +// Pointer types and related patterns +//===----------------------------------------------------------------------===// + +defvar ModeAddr32 = DefaultMode; +def ModeAddr64 : HwMode<"", [HasAddr64]>; + +def Addr0VT : ValueTypeByHwMode<[ModeAddr32, ModeAddr64], + [i32, i64]>; + +def p0 : PtrValueTypeByHwMode; + +// G_CONSTANT with p0 +def : Pat<(p0 (imm:$addr)), + (CONST_I32 imm:$addr)>, Requires<[HasAddr32]>; +def : Pat<(p0 (imm:$addr)), + (CONST_I64 imm:$addr)>, Requires<[HasAddr64]>; + +// G_LOAD of p0 +def : Pat<(p0 (load (AddrOps32 offset32_op:$offset, I32:$addr))), + (LOAD_I32_A32 0, + offset32_op:$offset, + I32:$addr)>, + Requires<[HasAddr32]>; + +def : Pat<(p0 (load (AddrOps64 offset64_op:$offset, I64:$addr))), + (LOAD_I64_A64 0, + offset64_op:$offset, + I64:$addr)>, + Requires<[HasAddr64]>; + +// G_STORE of p0 +def : Pat<(store p0:$val, (AddrOps32 offset32_op:$offset, I32:$addr)), + (STORE_I32_A32 0, + offset32_op:$offset, + I32:$addr, + p0:$val)>, + Requires<[HasAddr32]>; + +def : Pat<(store p0:$val, (AddrOps64 offset64_op:$offset, I64:$addr)), + (STORE_I64_A64 0, + offset64_op:$offset, + I64:$addr, + p0:$val)>, + Requires<[HasAddr64]>; + +// G_SELECT of p0 +def : Pat<(select I32:$cond, p0:$lhs, p0:$rhs), + (SELECT_I32 I32:$lhs, I32:$rhs, I32:$cond)>, Requires<[HasAddr32]>; +def : 
Pat<(select I32:$cond, p0:$lhs, p0:$rhs), + (SELECT_I64 I64:$lhs, I64:$rhs, I32:$cond)>, Requires<[HasAddr64]>; + +// ISD::SELECT requires its operand to conform to getBooleanContents, but +// WebAssembly's select interprets any non-zero value as true, so we can fold +// a setne with 0 into a select. +def : Pat<(select (i32 (setne I32:$cond, 0)), p0:$lhs, p0:$rhs), + (SELECT_I32 I32:$lhs, I32:$rhs, I32:$cond)>, Requires<[HasAddr32]>; +def : Pat<(select (i32 (setne I32:$cond, 0)), p0:$lhs, p0:$rhs), + (SELECT_I64 I64:$lhs, I64:$rhs, I32:$cond)>, Requires<[HasAddr64]>; + +// And again, this time with seteq instead of setne and the arms reversed. +def : Pat<(select (i32 (seteq I32:$cond, 0)), p0:$lhs, p0:$rhs), + (SELECT_I32 I32:$rhs, I32:$lhs, I32:$cond)>, Requires<[HasAddr32]>; +def : Pat<(select (i32 (seteq I32:$cond, 0)), p0:$lhs, p0:$rhs), + (SELECT_I64 I64:$rhs, I64:$lhs, I32:$cond)>, Requires<[HasAddr64]>; + + +// G_ICMP between p0 +multiclass ComparisonP0 { + def : Pat<(setcc p0:$lhs, p0:$rhs, cond), + (!cast(Name # "_I32") I32:$lhs, I32:$rhs)>, Requires<[HasAddr32]>; + def : Pat<(setcc p0:$lhs, p0:$rhs, cond), + (!cast(Name # "_I64") I64:$lhs, I64:$rhs)>, Requires<[HasAddr64]>; +} + +defm : ComparisonP0; +defm : ComparisonP0; +defm : ComparisonP0; +defm : ComparisonP0; +defm : ComparisonP0; +defm : ComparisonP0; +defm : ComparisonP0; +defm : ComparisonP0; +defm : ComparisonP0; +defm : ComparisonP0; + +//===----------------------------------------------------------------------===// +// Miscallenous patterns +//===----------------------------------------------------------------------===// + +def : Pat<(i32 (fp_to_sint_sat_gi F32:$src)), (I32_TRUNC_S_SAT_F32 F32:$src)>; +def : Pat<(i32 (fp_to_uint_sat_gi F32:$src)), (I32_TRUNC_U_SAT_F32 F32:$src)>; +def : Pat<(i32 (fp_to_sint_sat_gi F64:$src)), (I32_TRUNC_S_SAT_F64 F64:$src)>; +def : Pat<(i32 (fp_to_uint_sat_gi F64:$src)), (I32_TRUNC_U_SAT_F64 F64:$src)>; +def : Pat<(i64 (fp_to_sint_sat_gi F32:$src)), 
(I64_TRUNC_S_SAT_F32 F32:$src)>; +def : Pat<(i64 (fp_to_uint_sat_gi F32:$src)), (I64_TRUNC_U_SAT_F32 F32:$src)>; +def : Pat<(i64 (fp_to_sint_sat_gi F64:$src)), (I64_TRUNC_S_SAT_F64 F64:$src)>; +def : Pat<(i64 (fp_to_uint_sat_gi F64:$src)), (I64_TRUNC_U_SAT_F64 F64:$src)>; + +def : GINodeEquiv; + +def : Pat<(i32 (ctlz_zero_undef I32:$src)), (CLZ_I32 I32:$src)>; +def : Pat<(i64 (ctlz_zero_undef I64:$src)), (CLZ_I64 I64:$src)>; +def : Pat<(i32 (cttz_zero_undef I32:$src)), (CTZ_I32 I32:$src)>; +def : Pat<(i64 (cttz_zero_undef I64:$src)), (CTZ_I64 I64:$src)>; + +//===----------------------------------------------------------------------===// +// Complex pattern equivalents +//===----------------------------------------------------------------------===// + +def gi_AddrOps32 : GIComplexOperandMatcher, + GIComplexPatternEquiv; + +def gi_AddrOps64 : GIComplexOperandMatcher, + GIComplexPatternEquiv; diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyInstrMemory.td b/llvm/lib/Target/WebAssembly/WebAssemblyInstrMemory.td index 5d5af1b7a61bd..0cbe9d0c6a6a4 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyInstrMemory.td +++ b/llvm/lib/Target/WebAssembly/WebAssemblyInstrMemory.td @@ -33,12 +33,6 @@ def AddrOps32 : ComplexPattern; def AddrOps64 : ComplexPattern; -def gi_AddrOps32 : GIComplexOperandMatcher, - GIComplexPatternEquiv; - -def gi_AddrOps64 : GIComplexOperandMatcher, - GIComplexPatternEquiv; - // Defines atomic and non-atomic loads, regular and extending. 
multiclass WebAssemblyLoad reqs = []> { diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyTargetMachine.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyTargetMachine.cpp index 66959f9c2ac43..fdee5728e2aee 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyTargetMachine.cpp +++ b/llvm/lib/Target/WebAssembly/WebAssemblyTargetMachine.cpp @@ -687,6 +687,12 @@ bool WebAssemblyPassConfig::addRegBankSelect() { bool WebAssemblyPassConfig::addGlobalInstructionSelect() { addPass(new InstructionSelect(getOptLevel())); + + addPass(createWebAssemblyArgumentMove()); + addPass(createWebAssemblySetP2AlignOperands()); + addPass(createWebAssemblyFixBrTableDefaults()); + addPass(createWebAssemblyCleanCodeAfterTrap()); + return false; } From 120f9aa3f305c35e660159ca5142cc29beda18d3 Mon Sep 17 00:00:00 2001 From: Demetrius Kanios Date: Mon, 6 Oct 2025 13:43:16 -0700 Subject: [PATCH 12/17] Fix error due to difference HwMode signature --- llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.cpp | 3 --- llvm/lib/Target/WebAssembly/GISel/WebAssemblyLegalizerInfo.cpp | 2 -- llvm/lib/Target/WebAssembly/WebAssemblyGISel.td | 2 +- 3 files changed, 1 insertion(+), 6 deletions(-) diff --git a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.cpp b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.cpp index 1990d3e4e3adb..43ba7b1a983aa 100644 --- a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.cpp +++ b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.cpp @@ -440,10 +440,7 @@ bool WebAssemblyCallLowering::lowerReturn(MachineIRBuilder &MIRBuilder, Register SwiftErrorVReg) const { auto MIB = MIRBuilder.buildInstrNoInsert(WebAssembly::RETURN); MachineFunction &MF = MIRBuilder.getMF(); - auto &TLI = *getTLI(); auto &Subtarget = MF.getSubtarget(); - auto &TRI = *Subtarget.getRegisterInfo(); - auto &TII = *Subtarget.getInstrInfo(); auto &RBI = *Subtarget.getRegBankInfo(); assert(((Val && !VRegs.empty()) || (!Val && VRegs.empty())) && diff --git 
a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyLegalizerInfo.cpp b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyLegalizerInfo.cpp index 633dd48cb3ac6..3e9d5957a22bc 100644 --- a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyLegalizerInfo.cpp +++ b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyLegalizerInfo.cpp @@ -290,8 +290,6 @@ bool WebAssemblyLegalizerInfo::legalizeCustom( assert(CmpWidth == MRI.getType(RHS).getSizeInBits() && "LHS and RHS for FCMP are diffrent lengths???"); - auto IsI64 = CmpWidth == 64; - switch (Cond) { case CmpInst::FCMP_FALSE: return false; diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyGISel.td b/llvm/lib/Target/WebAssembly/WebAssemblyGISel.td index 331424759f0df..5ed2dede7a080 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyGISel.td +++ b/llvm/lib/Target/WebAssembly/WebAssemblyGISel.td @@ -21,7 +21,7 @@ include "WebAssembly.td" //===----------------------------------------------------------------------===// defvar ModeAddr32 = DefaultMode; -def ModeAddr64 : HwMode<"", [HasAddr64]>; +def ModeAddr64 : HwMode<[HasAddr64]>; def Addr0VT : ValueTypeByHwMode<[ModeAddr32, ModeAddr64], [i32, i64]>; From bb523f86121a161e482b2a8e7255fb855f5d5d86 Mon Sep 17 00:00:00 2001 From: Demetrius Kanios Date: Thu, 9 Oct 2025 12:15:09 -0700 Subject: [PATCH 13/17] Setup combiners --- llvm/lib/Target/WebAssembly/CMakeLists.txt | 11 +- .../GISel/WebAssemblyCallLowering.cpp | 11 +- .../GISel/WebAssemblyLegalizerInfo.cpp | 2 +- .../WebAssemblyO0PreLegalizerCombiner.cpp | 154 ++++++++++++++++ .../WebAssemblyPostLegalizerCombiner.cpp | 166 +++++++++++++++++ .../GISel/WebAssemblyPreLegalizerCombiner.cpp | 172 ++++++++++++++++++ .../GISel/WebAssemblyRegisterBankInfo.cpp | 4 + llvm/lib/Target/WebAssembly/WebAssembly.h | 9 + .../Target/WebAssembly/WebAssemblyCombine.td | 26 +++ .../Target/WebAssembly/WebAssemblyGISel.td | 1 + .../WebAssembly/WebAssemblyTargetMachine.cpp | 17 ++ 11 files changed, 564 insertions(+), 9 deletions(-) create mode 100644 
llvm/lib/Target/WebAssembly/GISel/WebAssemblyO0PreLegalizerCombiner.cpp create mode 100644 llvm/lib/Target/WebAssembly/GISel/WebAssemblyPostLegalizerCombiner.cpp create mode 100644 llvm/lib/Target/WebAssembly/GISel/WebAssemblyPreLegalizerCombiner.cpp create mode 100644 llvm/lib/Target/WebAssembly/WebAssemblyCombine.td diff --git a/llvm/lib/Target/WebAssembly/CMakeLists.txt b/llvm/lib/Target/WebAssembly/CMakeLists.txt index a5a23d39cf700..a295e2daac20f 100644 --- a/llvm/lib/Target/WebAssembly/CMakeLists.txt +++ b/llvm/lib/Target/WebAssembly/CMakeLists.txt @@ -16,14 +16,23 @@ tablegen(LLVM WebAssemblyGenSubtargetInfo.inc -gen-subtarget) set(LLVM_TARGET_DEFINITIONS WebAssemblyGISel.td) tablegen(LLVM WebAssemblyGenGlobalISel.inc -gen-global-isel) +tablegen(LLVM WebAssemblyGenO0PreLegalizeGICombiner.inc -gen-global-isel-combiner + -combiners="WebAssemblyO0PreLegalizerCombiner") +tablegen(LLVM WebAssemblyGenPreLegalizeGICombiner.inc -gen-global-isel-combiner + -combiners="WebAssemblyPreLegalizerCombiner") +tablegen(LLVM WebAssemblyGenPostLegalizeGICombiner.inc -gen-global-isel-combiner + -combiners="WebAssemblyPostLegalizerCombiner") add_public_tablegen_target(WebAssemblyCommonTableGen) add_llvm_target(WebAssemblyCodeGen GISel/WebAssemblyCallLowering.cpp GISel/WebAssemblyInstructionSelector.cpp - GISel/WebAssemblyRegisterBankInfo.cpp + GISel/WebAssemblyO0PreLegalizerCombiner.cpp + GISel/WebAssemblyPostLegalizerCombiner.cpp + GISel/WebAssemblyPreLegalizerCombiner.cpp GISel/WebAssemblyLegalizerInfo.cpp + GISel/WebAssemblyRegisterBankInfo.cpp WebAssemblyAddMissingPrototypes.cpp WebAssemblyArgumentMove.cpp WebAssemblyAsmPrinter.cpp diff --git a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.cpp b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.cpp index 43ba7b1a983aa..f852716f86268 100644 --- a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.cpp +++ b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.cpp @@ -521,13 +521,10 @@ bool 
WebAssemblyCallLowering::lowerReturn(MachineIRBuilder &MIRBuilder, } for (unsigned Part = 0; Part < NumParts; ++Part) { - auto NewOutReg = Arg.Regs[Part]; - if (!RBI.constrainGenericRegister(NewOutReg, NewRegClass, MRI)) { - NewOutReg = MRI.createGenericVirtualRegister(NewLLT); - assert(RBI.constrainGenericRegister(NewOutReg, NewRegClass, MRI) && - "Couldn't constrain brand-new register?"); - MIRBuilder.buildCopy(NewOutReg, Arg.Regs[Part]); - } + auto NewOutReg = MRI.createGenericVirtualRegister(NewLLT); + assert(RBI.constrainGenericRegister(NewOutReg, NewRegClass, MRI) && + "Couldn't constrain brand-new register?"); + MIRBuilder.buildCopy(NewOutReg, Arg.Regs[Part]); MIB.addUse(NewOutReg); } } diff --git a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyLegalizerInfo.cpp b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyLegalizerInfo.cpp index 3e9d5957a22bc..3f4f318961dbf 100644 --- a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyLegalizerInfo.cpp +++ b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyLegalizerInfo.cpp @@ -87,7 +87,7 @@ WebAssemblyLegalizerInfo::WebAssemblyLegalizerInfo( .clampScalar(0, s32, s64); getActionDefinitionsBuilder({G_ASHR, G_LSHR, G_SHL, G_CTLZ, G_CTLZ_ZERO_UNDEF, - G_CTTZ, G_CTTZ_ZERO_UNDEF, G_CTPOP}) + G_CTTZ, G_CTTZ_ZERO_UNDEF, G_CTPOP, G_ROTL, G_ROTR}) .legalFor({{s32, s32}, {s64, s64}}) .widenScalarToNextPow2(0) .clampScalar(0, s32, s64) diff --git a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyO0PreLegalizerCombiner.cpp b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyO0PreLegalizerCombiner.cpp new file mode 100644 index 0000000000000..521aa2535e362 --- /dev/null +++ b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyO0PreLegalizerCombiner.cpp @@ -0,0 +1,154 @@ +//=== WebAssemblyVO0PreLegalizerCombiner.cpp ------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This pass does combining of machine instructions at the generic MI level, +// before the legalizer. +// +//===----------------------------------------------------------------------===// + +#include "WebAssembly.h" +#include "WebAssemblySubtarget.h" +#include "llvm/CodeGen/GlobalISel/Combiner.h" +#include "llvm/CodeGen/GlobalISel/CombinerHelper.h" +#include "llvm/CodeGen/GlobalISel/CombinerInfo.h" +#include "llvm/CodeGen/GlobalISel/GIMatchTableExecutorImpl.h" +#include "llvm/CodeGen/GlobalISel/GISelValueTracking.h" +#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h" +#include "llvm/CodeGen/MachineDominators.h" +#include "llvm/CodeGen/MachineFunction.h" +#include "llvm/CodeGen/MachineFunctionPass.h" +#include "llvm/CodeGen/TargetPassConfig.h" + +#define GET_GICOMBINER_DEPS +#include "WebAssemblyGenO0PreLegalizeGICombiner.inc" +#undef GET_GICOMBINER_DEPS + +#define DEBUG_TYPE "wasm-O0-prelegalizer-combiner" + +using namespace llvm; + +namespace { +#define GET_GICOMBINER_TYPES +#include "WebAssemblyGenO0PreLegalizeGICombiner.inc" +#undef GET_GICOMBINER_TYPES + +class WebAssemblyO0PreLegalizerCombinerImpl : public Combiner { +protected: + const CombinerHelper Helper; + const WebAssemblyO0PreLegalizerCombinerImplRuleConfig &RuleConfig; + const WebAssemblySubtarget &STI; + +public: + WebAssemblyO0PreLegalizerCombinerImpl( + MachineFunction &MF, CombinerInfo &CInfo, const TargetPassConfig *TPC, + GISelValueTracking &VT, GISelCSEInfo *CSEInfo, + const WebAssemblyO0PreLegalizerCombinerImplRuleConfig &RuleConfig, + const WebAssemblySubtarget &STI); + + static const char *getName() { return "WebAssemblyO0PreLegalizerCombiner"; } + + bool tryCombineAll(MachineInstr &I) const override; + +private: +#define GET_GICOMBINER_CLASS_MEMBERS +#include "WebAssemblyGenO0PreLegalizeGICombiner.inc" +#undef GET_GICOMBINER_CLASS_MEMBERS 
+}; + +#define GET_GICOMBINER_IMPL +#include "WebAssemblyGenO0PreLegalizeGICombiner.inc" +#undef GET_GICOMBINER_IMPL + +WebAssemblyO0PreLegalizerCombinerImpl::WebAssemblyO0PreLegalizerCombinerImpl( + MachineFunction &MF, CombinerInfo &CInfo, const TargetPassConfig *TPC, + GISelValueTracking &VT, GISelCSEInfo *CSEInfo, + const WebAssemblyO0PreLegalizerCombinerImplRuleConfig &RuleConfig, + const WebAssemblySubtarget &STI) + : Combiner(MF, CInfo, TPC, &VT, CSEInfo), + Helper(Observer, B, /*IsPreLegalize*/ true, &VT), RuleConfig(RuleConfig), + STI(STI), +#define GET_GICOMBINER_CONSTRUCTOR_INITS +#include "WebAssemblyGenO0PreLegalizeGICombiner.inc" +#undef GET_GICOMBINER_CONSTRUCTOR_INITS +{ +} + +// Pass boilerplate +// ================ + +class WebAssemblyO0PreLegalizerCombiner : public MachineFunctionPass { +public: + static char ID; + + WebAssemblyO0PreLegalizerCombiner(); + + StringRef getPassName() const override { + return "WebAssemblyO0PreLegalizerCombiner"; + } + + bool runOnMachineFunction(MachineFunction &MF) override; + + void getAnalysisUsage(AnalysisUsage &AU) const override; + +private: + WebAssemblyO0PreLegalizerCombinerImplRuleConfig RuleConfig; +}; +} // end anonymous namespace + +void WebAssemblyO0PreLegalizerCombiner::getAnalysisUsage(AnalysisUsage &AU) const { + AU.addRequired(); + AU.setPreservesCFG(); + getSelectionDAGFallbackAnalysisUsage(AU); + AU.addRequired(); + AU.addPreserved(); + MachineFunctionPass::getAnalysisUsage(AU); +} + +WebAssemblyO0PreLegalizerCombiner::WebAssemblyO0PreLegalizerCombiner() + : MachineFunctionPass(ID) { + if (!RuleConfig.parseCommandLineOption()) + report_fatal_error("Invalid rule identifier"); +} + +bool WebAssemblyO0PreLegalizerCombiner::runOnMachineFunction(MachineFunction &MF) { + if (MF.getProperties().hasFailedISel()) + return false; + auto &TPC = getAnalysis(); + + const Function &F = MF.getFunction(); + GISelValueTracking *VT = + &getAnalysis().get(MF); + + const WebAssemblySubtarget &ST = MF.getSubtarget(); 
+ + CombinerInfo CInfo(/*AllowIllegalOps*/ true, /*ShouldLegalizeIllegal*/ false, + /*LegalizerInfo*/ nullptr, /*EnableOpt*/ false, + F.hasOptSize(), F.hasMinSize()); + // Disable fixed-point iteration in the Combiner. This improves compile-time + // at the cost of possibly missing optimizations. See PR#94291 for details. + CInfo.MaxIterations = 1; + + WebAssemblyO0PreLegalizerCombinerImpl Impl(MF, CInfo, &TPC, *VT, + /*CSEInfo*/ nullptr, RuleConfig, ST); + return Impl.combineMachineInstrs(); +} + +char WebAssemblyO0PreLegalizerCombiner::ID = 0; +INITIALIZE_PASS_BEGIN(WebAssemblyO0PreLegalizerCombiner, DEBUG_TYPE, + "Combine WebAssembly machine instrs before legalization", + false, false) +INITIALIZE_PASS_DEPENDENCY(TargetPassConfig) +INITIALIZE_PASS_DEPENDENCY(GISelValueTrackingAnalysisLegacy) +INITIALIZE_PASS_DEPENDENCY(GISelCSEAnalysisWrapperPass) +INITIALIZE_PASS_END(WebAssemblyO0PreLegalizerCombiner, DEBUG_TYPE, + "Combine WebAssembly machine instrs before legalization", false, + false) + +FunctionPass *llvm::createWebAssemblyO0PreLegalizerCombiner() { + return new WebAssemblyO0PreLegalizerCombiner(); +} diff --git a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyPostLegalizerCombiner.cpp b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyPostLegalizerCombiner.cpp new file mode 100644 index 0000000000000..4bf687fc4785c --- /dev/null +++ b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyPostLegalizerCombiner.cpp @@ -0,0 +1,166 @@ +//=== WebAssemblyPostLegalizerCombiner.cpp --------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// Post-legalization combines on generic MachineInstrs. +/// +/// The combines here must preserve instruction legality. 
+/// +/// Combines which don't rely on instruction legality should go in the +/// WebAssemblyPreLegalizerCombiner. +/// +//===----------------------------------------------------------------------===// + +#include "WebAssembly.h" +#include "WebAssemblyTargetMachine.h" +#include "llvm/CodeGen/GlobalISel/CSEInfo.h" +#include "llvm/CodeGen/GlobalISel/Combiner.h" +#include "llvm/CodeGen/GlobalISel/CombinerHelper.h" +#include "llvm/CodeGen/GlobalISel/CombinerInfo.h" +#include "llvm/CodeGen/GlobalISel/GIMatchTableExecutorImpl.h" +#include "llvm/CodeGen/GlobalISel/GISelValueTracking.h" +#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h" +#include "llvm/CodeGen/MachineDominators.h" +#include "llvm/CodeGen/MachineFunctionPass.h" +#include "llvm/CodeGen/TargetPassConfig.h" + +#define GET_GICOMBINER_DEPS +#include "WebAssemblyGenPostLegalizeGICombiner.inc" +#undef GET_GICOMBINER_DEPS + +#define DEBUG_TYPE "wasm-postlegalizer-combiner" + +using namespace llvm; + +namespace { + +#define GET_GICOMBINER_TYPES +#include "WebAssemblyGenPostLegalizeGICombiner.inc" +#undef GET_GICOMBINER_TYPES + +class WebAssemblyPostLegalizerCombinerImpl : public Combiner { +protected: + const CombinerHelper Helper; + const WebAssemblyPostLegalizerCombinerImplRuleConfig &RuleConfig; + const WebAssemblySubtarget &STI; + +public: + WebAssemblyPostLegalizerCombinerImpl( + MachineFunction &MF, CombinerInfo &CInfo, const TargetPassConfig *TPC, + GISelValueTracking &VT, GISelCSEInfo *CSEInfo, + const WebAssemblyPostLegalizerCombinerImplRuleConfig &RuleConfig, + const WebAssemblySubtarget &STI, MachineDominatorTree *MDT, + const LegalizerInfo *LI); + + static const char *getName() { return "WebAssemblyPostLegalizerCombiner"; } + + bool tryCombineAll(MachineInstr &I) const override; + +private: +#define GET_GICOMBINER_CLASS_MEMBERS +#include "WebAssemblyGenPostLegalizeGICombiner.inc" +#undef GET_GICOMBINER_CLASS_MEMBERS +}; + +#define GET_GICOMBINER_IMPL +#include 
"WebAssemblyGenPostLegalizeGICombiner.inc" +#undef GET_GICOMBINER_IMPL + +WebAssemblyPostLegalizerCombinerImpl::WebAssemblyPostLegalizerCombinerImpl( + MachineFunction &MF, CombinerInfo &CInfo, const TargetPassConfig *TPC, + GISelValueTracking &VT, GISelCSEInfo *CSEInfo, + const WebAssemblyPostLegalizerCombinerImplRuleConfig &RuleConfig, + const WebAssemblySubtarget &STI, MachineDominatorTree *MDT, + const LegalizerInfo *LI) + : Combiner(MF, CInfo, TPC, &VT, CSEInfo), + Helper(Observer, B, /*IsPreLegalize*/ false, &VT, MDT, LI), + RuleConfig(RuleConfig), STI(STI), +#define GET_GICOMBINER_CONSTRUCTOR_INITS +#include "WebAssemblyGenPostLegalizeGICombiner.inc" +#undef GET_GICOMBINER_CONSTRUCTOR_INITS +{ +} + +class WebAssemblyPostLegalizerCombiner : public MachineFunctionPass { +public: + static char ID; + + WebAssemblyPostLegalizerCombiner(); + + StringRef getPassName() const override { + return "WebAssemblyPostLegalizerCombiner"; + } + + bool runOnMachineFunction(MachineFunction &MF) override; + void getAnalysisUsage(AnalysisUsage &AU) const override; + +private: + WebAssemblyPostLegalizerCombinerImplRuleConfig RuleConfig; +}; +} // end anonymous namespace + +void WebAssemblyPostLegalizerCombiner::getAnalysisUsage(AnalysisUsage &AU) const { + AU.addRequired(); + AU.setPreservesCFG(); + getSelectionDAGFallbackAnalysisUsage(AU); + AU.addRequired(); + AU.addPreserved(); + AU.addRequired(); + AU.addPreserved(); + AU.addRequired(); + AU.addPreserved(); + MachineFunctionPass::getAnalysisUsage(AU); +} + +WebAssemblyPostLegalizerCombiner::WebAssemblyPostLegalizerCombiner() + : MachineFunctionPass(ID) { + if (!RuleConfig.parseCommandLineOption()) + report_fatal_error("Invalid rule identifier"); +} + +bool WebAssemblyPostLegalizerCombiner::runOnMachineFunction(MachineFunction &MF) { + if (MF.getProperties().hasFailedISel()) + return false; + assert(MF.getProperties().hasLegalized() && "Expected a legalized function?"); + auto *TPC = &getAnalysis(); + const Function &F = 
MF.getFunction(); + bool EnableOpt = + MF.getTarget().getOptLevel() != CodeGenOptLevel::None && !skipFunction(F); + + const WebAssemblySubtarget &ST = MF.getSubtarget(); + const auto *LI = ST.getLegalizerInfo(); + + GISelValueTracking *VT = + &getAnalysis().get(MF); + MachineDominatorTree *MDT = + &getAnalysis().getDomTree(); + GISelCSEAnalysisWrapper &Wrapper = + getAnalysis().getCSEWrapper(); + auto *CSEInfo = &Wrapper.get(TPC->getCSEConfig()); + + CombinerInfo CInfo(/*AllowIllegalOps*/ true, /*ShouldLegalizeIllegal*/ false, + /*LegalizerInfo*/ nullptr, EnableOpt, F.hasOptSize(), + F.hasMinSize()); + WebAssemblyPostLegalizerCombinerImpl Impl(MF, CInfo, TPC, *VT, CSEInfo, RuleConfig, + ST, MDT, LI); + return Impl.combineMachineInstrs(); +} + +char WebAssemblyPostLegalizerCombiner::ID = 0; +INITIALIZE_PASS_BEGIN(WebAssemblyPostLegalizerCombiner, DEBUG_TYPE, + "Combine WebAssembly MachineInstrs after legalization", false, + false) +INITIALIZE_PASS_DEPENDENCY(TargetPassConfig) +INITIALIZE_PASS_DEPENDENCY(GISelValueTrackingAnalysisLegacy) +INITIALIZE_PASS_END(WebAssemblyPostLegalizerCombiner, DEBUG_TYPE, + "Combine WebAssembly MachineInstrs after legalization", false, + false) + +FunctionPass *llvm::createWebAssemblyPostLegalizerCombiner() { + return new WebAssemblyPostLegalizerCombiner(); +} diff --git a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyPreLegalizerCombiner.cpp b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyPreLegalizerCombiner.cpp new file mode 100644 index 0000000000000..a939eafd77392 --- /dev/null +++ b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyPreLegalizerCombiner.cpp @@ -0,0 +1,172 @@ +//=== WebAssemblyPreLegalizerCombiner.cpp ---------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This pass does combining of machine instructions at the generic MI level, +// before the legalizer. +// +//===----------------------------------------------------------------------===// + +#include "WebAssembly.h" +#include "WebAssemblySubtarget.h" +#include "llvm/CodeGen/GlobalISel/CSEInfo.h" +#include "llvm/CodeGen/GlobalISel/Combiner.h" +#include "llvm/CodeGen/GlobalISel/CombinerHelper.h" +#include "llvm/CodeGen/GlobalISel/CombinerInfo.h" +#include "llvm/CodeGen/GlobalISel/GIMatchTableExecutorImpl.h" +#include "llvm/CodeGen/GlobalISel/GISelValueTracking.h" +#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h" +#include "llvm/CodeGen/MachineDominators.h" +#include "llvm/CodeGen/MachineFunction.h" +#include "llvm/CodeGen/MachineFunctionPass.h" +#include "llvm/CodeGen/TargetPassConfig.h" +#include "llvm/Target/TargetMachine.h" + +#define GET_GICOMBINER_DEPS +#include "WebAssemblyGenPreLegalizeGICombiner.inc" +#undef GET_GICOMBINER_DEPS + +#define DEBUG_TYPE "wasm-prelegalizer-combiner" + +using namespace llvm; + +namespace { + +#define GET_GICOMBINER_TYPES +#include "WebAssemblyGenPreLegalizeGICombiner.inc" +#undef GET_GICOMBINER_TYPES + +class WebAssemblyPreLegalizerCombinerImpl : public Combiner { +protected: + const CombinerHelper Helper; + const WebAssemblyPreLegalizerCombinerImplRuleConfig &RuleConfig; + const WebAssemblySubtarget &STI; + +public: + WebAssemblyPreLegalizerCombinerImpl( + MachineFunction &MF, CombinerInfo &CInfo, const TargetPassConfig *TPC, + GISelValueTracking &VT, GISelCSEInfo *CSEInfo, + const WebAssemblyPreLegalizerCombinerImplRuleConfig &RuleConfig, + const WebAssemblySubtarget &STI, MachineDominatorTree *MDT, + const LegalizerInfo *LI); + + static const char *getName() { return "WebAssemblyPreLegalizerCombiner"; } + + bool tryCombineAll(MachineInstr &I) const override; + +private:
+#define GET_GICOMBINER_CLASS_MEMBERS +#include "WebAssemblyGenPreLegalizeGICombiner.inc" +#undef GET_GICOMBINER_CLASS_MEMBERS +}; + +#define GET_GICOMBINER_IMPL +#include "WebAssemblyGenPreLegalizeGICombiner.inc" +#undef GET_GICOMBINER_IMPL + +WebAssemblyPreLegalizerCombinerImpl::WebAssemblyPreLegalizerCombinerImpl( + MachineFunction &MF, CombinerInfo &CInfo, const TargetPassConfig *TPC, + GISelValueTracking &VT, GISelCSEInfo *CSEInfo, + const WebAssemblyPreLegalizerCombinerImplRuleConfig &RuleConfig, + const WebAssemblySubtarget &STI, MachineDominatorTree *MDT, + const LegalizerInfo *LI) + : Combiner(MF, CInfo, TPC, &VT, CSEInfo), + Helper(Observer, B, /*IsPreLegalize*/ true, &VT, MDT, LI), + RuleConfig(RuleConfig), STI(STI), +#define GET_GICOMBINER_CONSTRUCTOR_INITS +#include "WebAssemblyGenPreLegalizeGICombiner.inc" +#undef GET_GICOMBINER_CONSTRUCTOR_INITS +{ +} + +// Pass boilerplate +// ================ + +class WebAssemblyPreLegalizerCombiner : public MachineFunctionPass { +public: + static char ID; + + WebAssemblyPreLegalizerCombiner(); + + StringRef getPassName() const override { return "WebAssemblyPreLegalizerCombiner"; } + + bool runOnMachineFunction(MachineFunction &MF) override; + + void getAnalysisUsage(AnalysisUsage &AU) const override; + +private: + WebAssemblyPreLegalizerCombinerImplRuleConfig RuleConfig; +}; +} // end anonymous namespace + +void WebAssemblyPreLegalizerCombiner::getAnalysisUsage(AnalysisUsage &AU) const { + AU.addRequired<TargetPassConfig>(); + AU.setPreservesCFG(); + getSelectionDAGFallbackAnalysisUsage(AU); + AU.addRequired<GISelValueTrackingAnalysisLegacy>(); + AU.addPreserved<GISelValueTrackingAnalysisLegacy>(); + AU.addRequired<MachineDominatorTreeWrapperPass>(); + AU.addPreserved<MachineDominatorTreeWrapperPass>(); + AU.addRequired<GISelCSEAnalysisWrapperPass>(); + AU.addPreserved<GISelCSEAnalysisWrapperPass>(); + MachineFunctionPass::getAnalysisUsage(AU); +} + +WebAssemblyPreLegalizerCombiner::WebAssemblyPreLegalizerCombiner() + : MachineFunctionPass(ID) { + if (!RuleConfig.parseCommandLineOption()) + report_fatal_error("Invalid rule identifier"); +} + +bool WebAssemblyPreLegalizerCombiner::runOnMachineFunction(MachineFunction
&MF) { + if (MF.getProperties().hasFailedISel()) + return false; + auto &TPC = getAnalysis<TargetPassConfig>(); + + // Enable CSE. + GISelCSEAnalysisWrapper &Wrapper = + getAnalysis<GISelCSEAnalysisWrapperPass>().getCSEWrapper(); + auto *CSEInfo = &Wrapper.get(TPC.getCSEConfig()); + + const WebAssemblySubtarget &ST = MF.getSubtarget<WebAssemblySubtarget>(); + const auto *LI = ST.getLegalizerInfo(); + + const Function &F = MF.getFunction(); + bool EnableOpt = + MF.getTarget().getOptLevel() != CodeGenOptLevel::None && !skipFunction(F); + GISelValueTracking *VT = + &getAnalysis<GISelValueTrackingAnalysisLegacy>().get(MF); + MachineDominatorTree *MDT = + &getAnalysis<MachineDominatorTreeWrapperPass>().getDomTree(); + CombinerInfo CInfo(/*AllowIllegalOps*/ true, /*ShouldLegalizeIllegal*/ false, + /*LegalizerInfo*/ nullptr, EnableOpt, F.hasOptSize(), + F.hasMinSize()); + // Disable fixed-point iteration to reduce compile-time + CInfo.MaxIterations = 1; + CInfo.ObserverLvl = CombinerInfo::ObserverLevel::SinglePass; + // This is the first Combiner, so the input IR might contain dead + // instructions. + CInfo.EnableFullDCE = true; + WebAssemblyPreLegalizerCombinerImpl Impl(MF, CInfo, &TPC, *VT, CSEInfo, RuleConfig, + ST, MDT, LI); + return Impl.combineMachineInstrs(); +} + +char WebAssemblyPreLegalizerCombiner::ID = 0; +INITIALIZE_PASS_BEGIN(WebAssemblyPreLegalizerCombiner, DEBUG_TYPE, + "Combine WebAssembly machine instrs before legalization", false, + false) +INITIALIZE_PASS_DEPENDENCY(TargetPassConfig) +INITIALIZE_PASS_DEPENDENCY(GISelValueTrackingAnalysisLegacy) +INITIALIZE_PASS_DEPENDENCY(GISelCSEAnalysisWrapperPass) +INITIALIZE_PASS_END(WebAssemblyPreLegalizerCombiner, DEBUG_TYPE, + "Combine WebAssembly machine instrs before legalization", false, + false) + +FunctionPass *llvm::createWebAssemblyPreLegalizerCombiner() { + return new WebAssemblyPreLegalizerCombiner(); +} diff --git a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyRegisterBankInfo.cpp b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyRegisterBankInfo.cpp index 096cd2125ec22..edb217a0c71d6 ---
a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyRegisterBankInfo.cpp +++ b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyRegisterBankInfo.cpp @@ -167,6 +167,8 @@ WebAssemblyRegisterBankInfo::getInstrMapping(const MachineInstr &MI) const { case G_CTPOP: case G_FSHL: case G_FSHR: + case G_ROTR: + case G_ROTL: OperandsMapping = &Op0IntValueMapping; break; case G_FADD: @@ -218,6 +220,8 @@ WebAssemblyRegisterBankInfo::getInstrMapping(const MachineInstr &MI) const { break; } case G_LOAD: + case G_ZEXTLOAD: + case G_SEXTLOAD: case G_STORE: if (MRI.getType(MI.getOperand(1).getReg()).getAddressSpace() != 0) break; diff --git a/llvm/lib/Target/WebAssembly/WebAssembly.h b/llvm/lib/Target/WebAssembly/WebAssembly.h index 0c56c5077c563..5901e4b1a47aa 100644 --- a/llvm/lib/Target/WebAssembly/WebAssembly.h +++ b/llvm/lib/Target/WebAssembly/WebAssembly.h @@ -41,6 +41,15 @@ createWebAssemblyInstructionSelector(const WebAssemblyTargetMachine &, const WebAssemblySubtarget &, const WebAssemblyRegisterBankInfo &); +FunctionPass *createWebAssemblyPostLegalizerCombiner(); +void initializeWebAssemblyPostLegalizerCombinerPass(PassRegistry &); + +FunctionPass *createWebAssemblyO0PreLegalizerCombiner(); +void initializeWebAssemblyO0PreLegalizerCombinerPass(PassRegistry &); + +FunctionPass *createWebAssemblyPreLegalizerCombiner(); +void initializeWebAssemblyPreLegalizerCombinerPass(PassRegistry &); + // ISel and immediate followup passes. FunctionPass *createWebAssemblyISelDag(WebAssemblyTargetMachine &TM, CodeGenOptLevel OptLevel); diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyCombine.td b/llvm/lib/Target/WebAssembly/WebAssemblyCombine.td new file mode 100644 index 0000000000000..d70fd27fa2ad3 --- /dev/null +++ b/llvm/lib/Target/WebAssembly/WebAssemblyCombine.td @@ -0,0 +1,26 @@ +//=- WebAssemblyCombine.td - Define WASM Combine Rules -------*- tablegen -*-=// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// +//===----------------------------------------------------------------------===// + +include "llvm/Target/GlobalISel/Combine.td" + +def WebAssemblyPreLegalizerCombiner: GICombiner< + "WebAssemblyPreLegalizerCombinerImpl", [all_combines]> { +} + +def WebAssemblyO0PreLegalizerCombiner: GICombiner< + "WebAssemblyO0PreLegalizerCombinerImpl", [optnone_combines]> { +} + +// Post-legalization combines which are primarily optimizations. +def WebAssemblyPostLegalizerCombiner + : GICombiner<"WebAssemblyPostLegalizerCombinerImpl", + []> { +} diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyGISel.td b/llvm/lib/Target/WebAssembly/WebAssemblyGISel.td index 5ed2dede7a080..55656731eaf3e 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyGISel.td +++ b/llvm/lib/Target/WebAssembly/WebAssemblyGISel.td @@ -14,6 +14,7 @@ //===----------------------------------------------------------------------===// include "WebAssembly.td" +include "WebAssemblyCombine.td" //===----------------------------------------------------------------------===// diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyTargetMachine.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyTargetMachine.cpp index fdee5728e2aee..39e9c871e21ae 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyTargetMachine.cpp +++ b/llvm/lib/Target/WebAssembly/WebAssemblyTargetMachine.cpp @@ -97,6 +97,9 @@ LLVMInitializeWebAssemblyTarget() { // Register backend passes auto &PR = *PassRegistry::getPassRegistry(); initializeGlobalISel(PR); + initializeWebAssemblyO0PreLegalizerCombinerPass(PR); + initializeWebAssemblyPreLegalizerCombinerPass(PR); + initializeWebAssemblyPostLegalizerCombinerPass(PR); initializeWebAssemblyAddMissingPrototypesPass(PR); initializeWebAssemblyLowerEmscriptenEHSjLjPass(PR); 
initializeLowerGlobalDtorsLegacyPassPass(PR); @@ -447,7 +450,9 @@ class WebAssemblyPassConfig final : public TargetPassConfig { bool addRegAssignAndRewriteOptimized() override { return false; } bool addIRTranslator() override; + void addPreLegalizeMachineIR() override; bool addLegalizeMachineIR() override; + void addPreRegBankSelect() override; bool addRegBankSelect() override; bool addGlobalInstructionSelect() override; }; @@ -675,11 +680,23 @@ bool WebAssemblyPassConfig::addIRTranslator() { return false; } +void WebAssemblyPassConfig::addPreLegalizeMachineIR() { + if (getOptLevel() == CodeGenOptLevel::None) { + addPass(createWebAssemblyO0PreLegalizerCombiner()); + } else { + addPass(createWebAssemblyPreLegalizerCombiner()); + } +} bool WebAssemblyPassConfig::addLegalizeMachineIR() { addPass(new Legalizer()); return false; } +void WebAssemblyPassConfig::addPreRegBankSelect() { + if (getOptLevel() != CodeGenOptLevel::None) + addPass(createWebAssemblyPostLegalizerCombiner()); +} + bool WebAssemblyPassConfig::addRegBankSelect() { addPass(new RegBankSelect()); return false; From ce5e73b62e34d4e2789291443f35759ab011c657 Mon Sep 17 00:00:00 2001 From: Demetrius Kanios Date: Fri, 10 Oct 2025 00:29:08 -0700 Subject: [PATCH 14/17] Implement more floating-point ops --- .../GISel/WebAssemblyLegalizerInfo.cpp | 75 +++++++++++++++++-- .../GISel/WebAssemblyRegisterBankInfo.cpp | 2 + .../Target/WebAssembly/WebAssemblyGISel.td | 4 + .../WebAssembly/WebAssemblyRegisterBanks.td | 2 +- 4 files changed, 74 insertions(+), 9 deletions(-) diff --git a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyLegalizerInfo.cpp b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyLegalizerInfo.cpp index 3f4f318961dbf..563570a2096fd 100644 --- a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyLegalizerInfo.cpp +++ b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyLegalizerInfo.cpp @@ -87,7 +87,8 @@ WebAssemblyLegalizerInfo::WebAssemblyLegalizerInfo( .clampScalar(0, s32, s64); getActionDefinitionsBuilder({G_ASHR, 
G_LSHR, G_SHL, G_CTLZ, G_CTLZ_ZERO_UNDEF, - G_CTTZ, G_CTTZ_ZERO_UNDEF, G_CTPOP, G_ROTL, G_ROTR}) + G_CTTZ, G_CTTZ_ZERO_UNDEF, G_CTPOP, G_ROTL, + G_ROTR}) .legalFor({{s32, s32}, {s64, s64}}) .widenScalarToNextPow2(0) .clampScalar(0, s32, s64) @@ -110,14 +111,16 @@ WebAssemblyLegalizerInfo::WebAssemblyLegalizerInfo( getActionDefinitionsBuilder({G_FADD, G_FSUB, G_FDIV, G_FMUL, G_FNEG, G_FABS, G_FCEIL, G_FFLOOR, G_FSQRT, G_INTRINSIC_TRUNC, G_FNEARBYINT, G_FRINT, G_INTRINSIC_ROUNDEVEN, - G_FMINIMUM, G_FMAXIMUM}) + G_FMINIMUM, G_FMAXIMUM, G_STRICT_FMUL}) .legalFor({s32, s64}) .minScalar(0, s32); - // TODO: _IEEE not lowering correctly? - getActionDefinitionsBuilder( - {G_FMINNUM, G_FMAXNUM, G_FMINNUM_IEEE, G_FMAXNUM_IEEE}) - .lowerFor({s32, s64}) + getActionDefinitionsBuilder({G_FMINNUM, G_FMAXNUM}) + .customFor({s32, s64}) + .minScalar(0, s32); + + getActionDefinitionsBuilder(G_FCANONICALIZE) + .customFor({s32, s64}) .minScalar(0, s32); getActionDefinitionsBuilder({G_FMA, G_FREM}) @@ -262,7 +265,63 @@ bool WebAssemblyLegalizerInfo::legalizeCustom( auto &MIRBuilder = Helper.MIRBuilder; switch (MI.getOpcode()) { - case WebAssembly::G_PTRTOINT: { + case TargetOpcode::G_FCANONICALIZE: { + auto One = MRI.createGenericVirtualRegister( + MRI.getType(MI.getOperand(0).getReg())); + MIRBuilder.buildFConstant(One, 1.0); + + MIRBuilder.buildInstr(TargetOpcode::G_STRICT_FMUL) + .addDef(MI.getOperand(0).getReg()) + .addUse(MI.getOperand(1).getReg()) + .addUse(One) + .setMIFlags(MI.getFlags()) + .setMIFlag(MachineInstr::MIFlag::NoFPExcept); + + MI.eraseFromParent(); + return true; + } + case TargetOpcode::G_FMINNUM: { + if (!MI.getFlag(MachineInstr::MIFlag::FmNoNans)) + return false; + + if (MI.getFlag(MachineInstr::MIFlag::FmNsz)) { + MIRBuilder.buildInstr(TargetOpcode::G_FMINIMUM) + .addDef(MI.getOperand(0).getReg()) + .addUse(MI.getOperand(1).getReg()) + .addUse(MI.getOperand(2).getReg()) + .setMIFlags(MI.getFlags()); + } else { + auto TmpReg = 
MRI.createGenericVirtualRegister(LLT::scalar(1)); + + MIRBuilder.buildFCmp(CmpInst::Predicate::FCMP_OLT, TmpReg, + MI.getOperand(1), MI.getOperand(2)); + MIRBuilder.buildSelect(MI.getOperand(0), TmpReg, MI.getOperand(1), + MI.getOperand(2)); + } + MI.eraseFromParent(); + return true; + } + case TargetOpcode::G_FMAXNUM: { + if (!MI.getFlag(MachineInstr::MIFlag::FmNoNans)) + return false; + if (MI.getFlag(MachineInstr::MIFlag::FmNsz)) { + MIRBuilder.buildInstr(TargetOpcode::G_FMAXIMUM) + .addDef(MI.getOperand(0).getReg()) + .addUse(MI.getOperand(1).getReg()) + .addUse(MI.getOperand(2).getReg()) + .setMIFlags(MI.getFlags()); + } else { + auto TmpReg = MRI.createGenericVirtualRegister(LLT::scalar(1)); + + MIRBuilder.buildFCmp(CmpInst::Predicate::FCMP_OGT, TmpReg, + MI.getOperand(1), MI.getOperand(2)); + MIRBuilder.buildSelect(MI.getOperand(0), TmpReg, MI.getOperand(1), + MI.getOperand(2)); + } + MI.eraseFromParent(); + return true; + } + case TargetOpcode::G_PTRTOINT: { auto TmpReg = MRI.createGenericVirtualRegister( LLT::scalar(MIRBuilder.getDataLayout().getPointerSizeInBits(0))); @@ -271,7 +330,7 @@ bool WebAssemblyLegalizerInfo::legalizeCustom( MI.eraseFromParent(); return true; } - case WebAssembly::G_INTTOPTR: { + case TargetOpcode::G_INTTOPTR: { auto TmpReg = MRI.createGenericVirtualRegister( LLT::scalar(MIRBuilder.getDataLayout().getPointerSizeInBits(0))); diff --git a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyRegisterBankInfo.cpp b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyRegisterBankInfo.cpp index edb217a0c71d6..b09b0f0f4b0ff 100644 --- a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyRegisterBankInfo.cpp +++ b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyRegisterBankInfo.cpp @@ -193,6 +193,8 @@ WebAssemblyRegisterBankInfo::getInstrMapping(const MachineInstr &MI) const { case G_FMA: case G_FREM: case G_FCOPYSIGN: + case G_FCANONICALIZE: + case G_STRICT_FMUL: OperandsMapping = &Op0FloatValueMapping; break; case G_SEXT_INREG: diff --git 
a/llvm/lib/Target/WebAssembly/WebAssemblyGISel.td b/llvm/lib/Target/WebAssembly/WebAssemblyGISel.td index 55656731eaf3e..3f70e3d0fc125 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyGISel.td +++ b/llvm/lib/Target/WebAssembly/WebAssemblyGISel.td @@ -123,6 +123,10 @@ def : Pat<(i64 (ctlz_zero_undef I64:$src)), (CLZ_I64 I64:$src)>; def : Pat<(i32 (cttz_zero_undef I32:$src)), (CTZ_I32 I32:$src)>; def : Pat<(i64 (cttz_zero_undef I64:$src)), (CTZ_I64 I64:$src)>; + +def : Pat<(f32 (strict_fmul F32:$lhs, F32:$rhs)), (MUL_F32 F32:$lhs, F32:$rhs)>; +def : Pat<(f64 (strict_fmul F64:$lhs, F64:$rhs)), (MUL_F64 F64:$lhs, F64:$rhs)>; + //===----------------------------------------------------------------------===// // Complex pattern equivalents //===----------------------------------------------------------------------===// diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyRegisterBanks.td b/llvm/lib/Target/WebAssembly/WebAssemblyRegisterBanks.td index 9ebece0e0bf09..7a527a321e2b7 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyRegisterBanks.td +++ b/llvm/lib/Target/WebAssembly/WebAssemblyRegisterBanks.td @@ -12,7 +12,7 @@ def I32RegBank : RegisterBank<"I32RegBank", [I32]>; def I64RegBank : RegisterBank<"I64RegBank", [I64]>; -def F32RegBank : RegisterBank<"F64RegBank", [F32]>; +def F32RegBank : RegisterBank<"F32RegBank", [F32]>; def F64RegBank : RegisterBank<"F64RegBank", [F64]>; def EXTERNREFRegBank : RegisterBank<"EXTERNREFRegBank", [EXTERNREF]>; From 21544ddfc621f651b5f7e8c9ccfcfb66e02dd3c1 Mon Sep 17 00:00:00 2001 From: Demetrius Kanios Date: Fri, 10 Oct 2025 13:36:11 -0700 Subject: [PATCH 15/17] Add scalarization for vector ops (fallback when SIMD isn't available) --- .../GISel/WebAssemblyCallLowering.cpp | 7 +- .../GISel/WebAssemblyLegalizerInfo.cpp | 101 ++++++++++++++++-- 2 files changed, 99 insertions(+), 9 deletions(-) diff --git a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.cpp b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.cpp 
index f852716f86268..5800d4b7fc399 100644 --- a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.cpp +++ b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyCallLowering.cpp @@ -659,8 +659,6 @@ bool WebAssemblyCallLowering::lowerFormalArguments( for (unsigned Part = 0; Part < NumParts; ++Part) { Arg.Regs[Part] = MRI.createGenericVirtualRegister(NewLLT); } - buildCopyFromRegs(MIRBuilder, Arg.OrigRegs, Arg.Regs, OrigLLT, NewLLT, - Arg.Flags[0], Arg.Ty->isFloatingPointTy()); } for (unsigned Part = 0; Part < NumParts; ++Part) { @@ -673,6 +671,11 @@ bool WebAssemblyCallLowering::lowerFormalArguments( MFI->addParam(NewVT); ++FinalArgIdx; } + + if (NumParts != 1 || OrigVT != NewVT) { + buildCopyFromRegs(MIRBuilder, Arg.OrigRegs, Arg.Regs, OrigLLT, NewLLT, + Arg.Flags[0], Arg.Ty->isFloatingPointTy()); + } } /**/ diff --git a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyLegalizerInfo.cpp b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyLegalizerInfo.cpp index 563570a2096fd..ce92105f5bbd5 100644 --- a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyLegalizerInfo.cpp +++ b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyLegalizerInfo.cpp @@ -47,6 +47,7 @@ WebAssemblyLegalizerInfo::WebAssemblyLegalizerInfo( getActionDefinitionsBuilder(G_SELECT) .legalFor({{s32, s32}, {s64, s32}, {p0, s32}}) + .scalarize(0) .widenScalarToNextPow2(0) .clampScalar(0, s32, s64) .clampScalar(1, s32, s32); @@ -55,12 +56,14 @@ WebAssemblyLegalizerInfo::WebAssemblyLegalizerInfo( getActionDefinitionsBuilder(G_ICMP) .legalFor({{s32, s32}, {s32, s64}, {s32, p0}}) + .scalarize(0) .widenScalarToNextPow2(1) .clampScalar(1, s32, s64) .clampScalar(0, s32, s32); getActionDefinitionsBuilder(G_FCMP) .customFor({{s32, s32}, {s32, s64}}) + .scalarize(0) .clampScalar(0, s32, s32) .libcall(); @@ -77,32 +80,36 @@ WebAssemblyLegalizerInfo::WebAssemblyLegalizerInfo( getActionDefinitionsBuilder(G_IMPLICIT_DEF) .legalFor({s32, s64, p0}) + .scalarize(0) .widenScalarToNextPow2(0) .clampScalar(0, s32, s64); 
getActionDefinitionsBuilder( {G_ADD, G_SUB, G_MUL, G_UDIV, G_SDIV, G_UREM, G_SREM}) .legalFor({s32, s64}) + .scalarize(0) .widenScalarToNextPow2(0) .clampScalar(0, s32, s64); getActionDefinitionsBuilder({G_ASHR, G_LSHR, G_SHL, G_CTLZ, G_CTLZ_ZERO_UNDEF, - G_CTTZ, G_CTTZ_ZERO_UNDEF, G_CTPOP, G_ROTL, - G_ROTR}) + G_CTTZ, G_CTTZ_ZERO_UNDEF, G_CTPOP}) .legalFor({{s32, s32}, {s64, s64}}) + .scalarize(0) .widenScalarToNextPow2(0) .clampScalar(0, s32, s64) .minScalarSameAs(1, 0) .maxScalarSameAs(1, 0); - getActionDefinitionsBuilder({G_FSHL, G_FSHR}) + getActionDefinitionsBuilder({G_FSHL, G_FSHR, G_ROTL, G_ROTR}) .legalFor({{s32, s32}, {s64, s64}}) + .scalarize(0) .lower(); getActionDefinitionsBuilder({G_SCMP, G_UCMP}).lower(); getActionDefinitionsBuilder({G_AND, G_OR, G_XOR}) .legalFor({s32, s64}) + .scalarize(0) .widenScalarToNextPow2(0) .clampScalar(0, s32, s64); @@ -113,14 +120,35 @@ WebAssemblyLegalizerInfo::WebAssemblyLegalizerInfo( G_FNEARBYINT, G_FRINT, G_INTRINSIC_ROUNDEVEN, G_FMINIMUM, G_FMAXIMUM, G_STRICT_FMUL}) .legalFor({s32, s64}) + .scalarize(0) .minScalar(0, s32); getActionDefinitionsBuilder({G_FMINNUM, G_FMAXNUM}) .customFor({s32, s64}) + .scalarize(0) .minScalar(0, s32); + getActionDefinitionsBuilder({G_VECREDUCE_OR, G_VECREDUCE_AND}).scalarize(1); + + getActionDefinitionsBuilder(G_BITCAST) + .customIf([=](const LegalityQuery &Query) { + // Handle casts from i1 vectors to scalars. 
+ LLT DstTy = Query.Types[0]; + LLT SrcTy = Query.Types[1]; + return DstTy.isScalar() && SrcTy.isVector() && + SrcTy.getScalarSizeInBits() == 1; + }) + .lowerIf([=](const LegalityQuery &Query) { + return Query.Types[0].isVector() != Query.Types[1].isVector(); + }) + .scalarize(0); + + getActionDefinitionsBuilder(G_MERGE_VALUES) + .lowerFor({{s64, s32}, {s64, s16}, {s64, s8}, {s32, s16}, {s32, s8}}); + getActionDefinitionsBuilder(G_FCANONICALIZE) .customFor({s32, s64}) + .scalarize(0) .minScalar(0, s32); getActionDefinitionsBuilder({G_FMA, G_FREM}) @@ -135,6 +163,7 @@ WebAssemblyLegalizerInfo::WebAssemblyLegalizerInfo( getActionDefinitionsBuilder(G_FCOPYSIGN) .legalFor({s32, s64}) + .scalarize(0) .minScalar(0, s32) .minScalarSameAs(1, 0) .maxScalarSameAs(1, 0); @@ -147,6 +176,7 @@ WebAssemblyLegalizerInfo::WebAssemblyLegalizerInfo( getActionDefinitionsBuilder({G_UITOFP, G_SITOFP}) .legalForCartesianProduct({s32, s64}, {s32, s64}) + .scalarize(0) .minScalar(1, s32) .widenScalarToNextPow2(1) .clampScalar(1, s32, s64); @@ -169,7 +199,8 @@ WebAssemblyLegalizerInfo::WebAssemblyLegalizerInfo( {s64, p0, s16, 1}, {s64, p0, s32, 1}}) .clampScalar(0, s32, s64) - .lowerIfMemSizeNotByteSizePow2(); + .lowerIfMemSizeNotByteSizePow2() + .scalarize(0); getActionDefinitionsBuilder(G_STORE) .legalForTypesWithMemDesc( @@ -181,7 +212,25 @@ WebAssemblyLegalizerInfo::WebAssemblyLegalizerInfo( {s64, p0, s16, 1}, {s64, p0, s32, 1}}) .clampScalar(0, s32, s64) - .lowerIfMemSizeNotByteSizePow2(); + .lowerIf([=](const LegalityQuery &Query) { + return Query.Types[0].isScalar() && + Query.Types[0] != Query.MMODescrs[0].MemoryTy; + }) + .bitcastIf( + [=](const LegalityQuery &Query) { + // Handle stores of i1 vectors. 
+ LLT Ty = Query.Types[0]; + return Ty.isVector() && Ty.getScalarSizeInBits() == 1; + }, + [=](const LegalityQuery &Query) { + const LLT VecTy = Query.Types[0]; + return std::pair(0, LLT::scalar(VecTy.getSizeInBits())); + }) + .scalarize(0); + + getActionDefinitionsBuilder( + {G_SHUFFLE_VECTOR, G_EXTRACT_VECTOR_ELT, G_INSERT_VECTOR_ELT}) + .lower(); getActionDefinitionsBuilder({G_ZEXTLOAD, G_SEXTLOAD}) .legalForTypesWithMemDesc({{s32, p0, s8, 1}, @@ -219,11 +268,13 @@ WebAssemblyLegalizerInfo::WebAssemblyLegalizerInfo( getActionDefinitionsBuilder(G_ANYEXT) .legalFor({{s64, s32}}) + .scalarize(0) .clampScalar(0, s32, s64) .clampScalar(1, s32, s64); getActionDefinitionsBuilder({G_SEXT, G_ZEXT}) .legalFor({{s64, s32}}) + .scalarize(0) .clampScalar(0, s32, s64) .clampScalar(1, s32, s64) .lower(); @@ -238,13 +289,14 @@ WebAssemblyLegalizerInfo::WebAssemblyLegalizerInfo( getActionDefinitionsBuilder(G_TRUNC) .legalFor({{s32, s64}}) + .scalarize(0) .clampScalar(0, s32, s64) .clampScalar(1, s32, s64) .lower(); - getActionDefinitionsBuilder(G_FPEXT).legalFor({{s64, s32}}); + getActionDefinitionsBuilder(G_FPEXT).legalFor({{s64, s32}}).scalarize(0); - getActionDefinitionsBuilder(G_FPTRUNC).legalFor({{s32, s64}}); + getActionDefinitionsBuilder(G_FPTRUNC).legalFor({{s32, s64}}).scalarize(0); getActionDefinitionsBuilder(G_VASTART).legalFor({p0}); getActionDefinitionsBuilder(G_VAARG) @@ -339,6 +391,41 @@ bool WebAssemblyLegalizerInfo::legalizeCustom( MI.eraseFromParent(); return true; } + case TargetOpcode::G_BITCAST: { + if (MIRBuilder.getMF().getSubtarget<WebAssemblySubtarget>().hasSIMD128()) { + return false; + } + + auto [DstReg, DstTy, SrcReg, SrcTy] = MI.getFirst2RegLLTs(); + + if (!DstTy.isScalar() || !SrcTy.isVector() || + SrcTy.getElementType() != LLT::scalar(1)) + return false; + + Register ResultReg = MRI.createGenericVirtualRegister(DstTy); + MIRBuilder.buildConstant(ResultReg, 0); + + for (unsigned i = 0; i < SrcTy.getNumElements(); i++) { + auto Elm =
MRI.createGenericVirtualRegister(LLT::scalar(1)); + auto ExtElm = MRI.createGenericVirtualRegister(DstTy); + auto ShiftedElm = MRI.createGenericVirtualRegister(DstTy); + auto Idx = MRI.createGenericVirtualRegister(LLT::scalar(8)); + auto NewResultReg = MRI.createGenericVirtualRegister(DstTy); + + MIRBuilder.buildConstant(Idx, i); + MIRBuilder.buildExtractVectorElement(Elm, SrcReg, Idx); + MIRBuilder.buildZExt(ExtElm, Elm, false); + MIRBuilder.buildShl(ShiftedElm, ExtElm, Idx); + MIRBuilder.buildOr(NewResultReg, ResultReg, ShiftedElm); + + ResultReg = NewResultReg; + } + + MIRBuilder.buildCopy(DstReg, ResultReg); + + MI.eraseFromParent(); + return true; + } case TargetOpcode::G_FCMP: { Register LHS = MI.getOperand(2).getReg(); Register RHS = MI.getOperand(3).getReg(); From 0835d56d9961ae75e95ae750705a1c905a7a60dc Mon Sep 17 00:00:00 2001 From: Demetrius Kanios Date: Sat, 11 Oct 2025 12:16:13 -0700 Subject: [PATCH 16/17] Enable regbankselect to choose f32/f64 for ambiguous instructions (e.g. 
load) --- .../GISel/WebAssemblyLegalizerInfo.cpp | 1 + .../GISel/WebAssemblyRegisterBankInfo.cpp | 203 +++++++++++++++++- .../GISel/WebAssemblyRegisterBankInfo.h | 28 +++ .../WebAssembly/WebAssemblyInstrMemory.td | 24 +-- .../WebAssembly/WebAssemblyInstrSIMD.td | 2 +- 5 files changed, 236 insertions(+), 22 deletions(-) diff --git a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyLegalizerInfo.cpp b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyLegalizerInfo.cpp index ce92105f5bbd5..35bc8af7b189a 100644 --- a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyLegalizerInfo.cpp +++ b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyLegalizerInfo.cpp @@ -38,6 +38,7 @@ WebAssemblyLegalizerInfo::WebAssemblyLegalizerInfo( getActionDefinitionsBuilder(G_PHI) .legalFor({p0, s32, s64}) + .scalarize(0) .widenScalarToNextPow2(0) .clampScalar(0, s32, s64); getActionDefinitionsBuilder(G_BR).alwaysLegal(); diff --git a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyRegisterBankInfo.cpp b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyRegisterBankInfo.cpp index b09b0f0f4b0ff..b57c79721c403 100644 --- a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyRegisterBankInfo.cpp +++ b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyRegisterBankInfo.cpp @@ -61,14 +61,113 @@ using namespace llvm; WebAssemblyRegisterBankInfo::WebAssemblyRegisterBankInfo( const TargetRegisterInfo &TRI) {} +bool WebAssemblyRegisterBankInfo::isPHIWithFPConstraints( + const MachineInstr &MI, const MachineRegisterInfo &MRI, + const WebAssemblyRegisterInfo &TRI, const unsigned Depth) const { + if (!MI.isPHI() || Depth > MaxFPRSearchDepth) + return false; + + return any_of(MRI.use_nodbg_instructions(MI.getOperand(0).getReg()), + [&](const MachineInstr &UseMI) { + if (onlyUsesFP(UseMI, MRI, TRI, Depth + 1)) + return true; + return isPHIWithFPConstraints(UseMI, MRI, TRI, Depth + 1); + }); +} + +bool WebAssemblyRegisterBankInfo::hasFPConstraints( + const MachineInstr &MI, const MachineRegisterInfo &MRI, + const WebAssemblyRegisterInfo &TRI, 
unsigned Depth) const { + unsigned Op = MI.getOpcode(); + // if (Op == TargetOpcode::G_INTRINSIC && isFPIntrinsic(MRI, MI)) + // return true; + + // Do we have an explicit floating point instruction? + if (isPreISelGenericFloatingPointOpcode(Op)) + return true; + + // No. Check if we have a copy-like instruction. If we do, then we could + // still be fed by floating point instructions. + if (Op != TargetOpcode::COPY && !MI.isPHI() && + !isPreISelGenericOptimizationHint(Op)) + return false; + + // Check if we already know the register bank. + auto *RB = getRegBank(MI.getOperand(0).getReg(), MRI, TRI); + if (RB == &WebAssembly::F32RegBank || RB == &WebAssembly::F64RegBank) + return true; + if (RB == &WebAssembly::I32RegBank || RB == &WebAssembly::I64RegBank) + return false; + + // We don't know anything. + // + // If we have a phi, we may be able to infer that it will be assigned a FPR + // based off of its inputs. + if (!MI.isPHI() || Depth > MaxFPRSearchDepth) + return false; + + return any_of(MI.explicit_uses(), [&](const MachineOperand &Op) { + return Op.isReg() && + onlyDefinesFP(*MRI.getVRegDef(Op.getReg()), MRI, TRI, Depth + 1); + }); +} + +bool WebAssemblyRegisterBankInfo::onlyUsesFP(const MachineInstr &MI, + const MachineRegisterInfo &MRI, + const WebAssemblyRegisterInfo &TRI, + unsigned Depth) const { + switch (MI.getOpcode()) { + case TargetOpcode::G_FPTOSI: + case TargetOpcode::G_FPTOUI: + case TargetOpcode::G_FPTOSI_SAT: + case TargetOpcode::G_FPTOUI_SAT: + case TargetOpcode::G_FCMP: + case TargetOpcode::G_LROUND: + case TargetOpcode::G_LLROUND: + return true; + default: + break; + } + return hasFPConstraints(MI, MRI, TRI, Depth); +} + +bool WebAssemblyRegisterBankInfo::onlyDefinesFP( + const MachineInstr &MI, const MachineRegisterInfo &MRI, + const WebAssemblyRegisterInfo &TRI, unsigned Depth) const { + switch (MI.getOpcode()) { + case TargetOpcode::G_SITOFP: + case TargetOpcode::G_UITOFP: + case TargetOpcode::G_EXTRACT_VECTOR_ELT: + case 
TargetOpcode::G_INSERT_VECTOR_ELT: + case TargetOpcode::G_BUILD_VECTOR: + case TargetOpcode::G_BUILD_VECTOR_TRUNC: + return true; + default: + break; + } + return hasFPConstraints(MI, MRI, TRI, Depth); +} + +bool WebAssemblyRegisterBankInfo::prefersFPUse( + const MachineInstr &MI, const MachineRegisterInfo &MRI, + const WebAssemblyRegisterInfo &TRI, unsigned Depth) const { + switch (MI.getOpcode()) { + case TargetOpcode::G_SITOFP: + case TargetOpcode::G_UITOFP: + return MRI.getType(MI.getOperand(0).getReg()).getSizeInBits() == + MRI.getType(MI.getOperand(1).getReg()).getSizeInBits(); + } + return onlyDefinesFP(MI, MRI, TRI, Depth); +} + const RegisterBankInfo::InstructionMapping & WebAssemblyRegisterBankInfo::getInstrMapping(const MachineInstr &MI) const { unsigned Opc = MI.getOpcode(); const MachineFunction &MF = *MI.getParent()->getParent(); const MachineRegisterInfo &MRI = MF.getRegInfo(); - const TargetSubtargetInfo &STI = MF.getSubtarget(); - const TargetRegisterInfo &TRI = *STI.getRegisterInfo(); + const WebAssemblySubtarget &STI = MF.getSubtarget(); + const WebAssemblyRegisterInfo &TRI = *STI.getRegisterInfo(); if ((Opc != TargetOpcode::COPY && !isPreISelGenericOpcode(Opc)) || Opc == TargetOpcode::G_PHI) { @@ -223,13 +322,50 @@ WebAssemblyRegisterBankInfo::getInstrMapping(const MachineInstr &MI) const { } case G_LOAD: case G_ZEXTLOAD: - case G_SEXTLOAD: - case G_STORE: + case G_SEXTLOAD: { if (MRI.getType(MI.getOperand(1).getReg()).getAddressSpace() != 0) break; + + auto *LoadValueMapping = &Op0IntValueMapping; + if (any_of(MRI.use_nodbg_instructions(MI.getOperand(0).getReg()), + [&](const MachineInstr &UseMI) { + // If we have at least one direct or indirect use + // in a FP instruction, + // assume this was a floating point load in the IR. If it was + // not, we would have had a bitcast before reaching that + // instruction. + // + // Int->FP conversion operations are also captured in + // prefersFPUse(). 
+ + if (isPHIWithFPConstraints(UseMI, MRI, TRI)) + return true; + + return onlyUsesFP(UseMI, MRI, TRI) || + prefersFPUse(UseMI, MRI, TRI); + })) + LoadValueMapping = &Op0FloatValueMapping; OperandsMapping = - getOperandsMapping({&Op0IntValueMapping, &Pointer0ValueMapping}); + getOperandsMapping({LoadValueMapping, &Pointer0ValueMapping}); + break; + } + case G_STORE: { + if (MRI.getType(MI.getOperand(1).getReg()).getAddressSpace() != 0) + break; + + Register VReg = MI.getOperand(0).getReg(); + if (!VReg) + break; + MachineInstr *DefMI = MRI.getVRegDef(VReg); + if (onlyDefinesFP(*DefMI, MRI, TRI)) { + OperandsMapping = + getOperandsMapping({&Op0FloatValueMapping, &Pointer0ValueMapping}); + } else { + OperandsMapping = + getOperandsMapping({&Op0IntValueMapping, &Pointer0ValueMapping}); + } break; + } case G_MEMCPY: case G_MEMMOVE: { if (MRI.getType(MI.getOperand(0).getReg()).getAddressSpace() != 0) @@ -375,11 +511,60 @@ WebAssemblyRegisterBankInfo::getInstrMapping(const MachineInstr &MI) const { // We only care about the mapping of the destination for COPY. 1); } - case G_SELECT: - OperandsMapping = getOperandsMapping( - {&Op0IntValueMapping, &WebAssembly::ValueMappings[WebAssembly::I32Idx], - &Op0IntValueMapping, &Op0IntValueMapping}); + case G_SELECT: { + // Try to minimize the number of copies. If we have more floating point + // constrained values than not, then we'll put everything on FPR. Otherwise, + // everything has to be on GPR. + unsigned NumFP = 0; + + // Check if the uses of the result always produce floating point values. + // + // For example: + // + // %z = G_SELECT %cond %x %y + // fpr = G_FOO %z ... + if (any_of(MRI.use_nodbg_instructions(MI.getOperand(0).getReg()), + [&](MachineInstr &MI) { return onlyUsesFP(MI, MRI, TRI); })) + ++NumFP; + + // Check if the defs of the source values always produce floating point + // values. + // + // For example: + // + // %x = G_SOMETHING_ALWAYS_FLOAT %a ... 
+ // %z = G_SELECT %cond %x %y + // + // Also check whether or not the sources have already been decided to be + // FPR. Keep track of this. + // + // This doesn't check the condition, since it's just whatever is in NZCV. + // This isn't passed explicitly in a register to fcsel/csel. + for (unsigned Idx = 2; Idx < 4; ++Idx) { + Register VReg = MI.getOperand(Idx).getReg(); + MachineInstr *DefMI = MRI.getVRegDef(VReg); + if (getRegBank(VReg, MRI, TRI) == &WebAssembly::F32RegBank || + getRegBank(VReg, MRI, TRI) == &WebAssembly::F64RegBank || + onlyDefinesFP(*DefMI, MRI, TRI)) + ++NumFP; + } + + // If we have more FP constraints than not, then move everything over to + // FPR. + if (NumFP >= 2) { + OperandsMapping = + getOperandsMapping({&Op0FloatValueMapping, + &WebAssembly::ValueMappings[WebAssembly::I32Idx], + &Op0FloatValueMapping, &Op0FloatValueMapping}); + + } else { + OperandsMapping = + getOperandsMapping({&Op0IntValueMapping, + &WebAssembly::ValueMappings[WebAssembly::I32Idx], + &Op0IntValueMapping, &Op0IntValueMapping}); + } break; + } case G_FPTOSI: case G_FPTOSI_SAT: case G_FPTOUI: diff --git a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyRegisterBankInfo.h b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyRegisterBankInfo.h index f0d95b56ef861..d2cde32cff45e 100644 --- a/llvm/lib/Target/WebAssembly/GISel/WebAssemblyRegisterBankInfo.h +++ b/llvm/lib/Target/WebAssembly/GISel/WebAssemblyRegisterBankInfo.h @@ -13,6 +13,7 @@ #ifndef LLVM_LIB_TARGET_WEBASSEMBLY_WEBASSEMBLYREGISTERBANKINFO_H #define LLVM_LIB_TARGET_WEBASSEMBLY_WEBASSEMBLYREGISTERBANKINFO_H +#include "WebAssemblyRegisterInfo.h" #include "llvm/CodeGen/RegisterBankInfo.h" #define GET_REGBANK_DECLARATIONS @@ -35,6 +36,33 @@ class WebAssemblyRegisterBankInfo final const InstructionMapping & getInstrMapping(const MachineInstr &MI) const override; + + /// Maximum recursion depth for hasFPConstraints. 
+ const unsigned MaxFPRSearchDepth = 2; + + /// \returns true if \p MI is a PHI that its def is used by + /// any instruction that onlyUsesFP. + bool isPHIWithFPConstraints(const MachineInstr &MI, + const MachineRegisterInfo &MRI, + const WebAssemblyRegisterInfo &TRI, + unsigned Depth = 0) const; + + /// \returns true if \p MI only uses and defines FPRs. + bool hasFPConstraints(const MachineInstr &MI, const MachineRegisterInfo &MRI, + const WebAssemblyRegisterInfo &TRI, + unsigned Depth = 0) const; + + /// \returns true if \p MI only uses FPRs. + bool onlyUsesFP(const MachineInstr &MI, const MachineRegisterInfo &MRI, + const WebAssemblyRegisterInfo &TRI, unsigned Depth = 0) const; + + /// \returns true if \p MI only defines FPRs. + bool onlyDefinesFP(const MachineInstr &MI, const MachineRegisterInfo &MRI, + const WebAssemblyRegisterInfo &TRI, unsigned Depth = 0) const; + + /// \returns true if \p MI can take both fpr and gpr uses, but prefers fp. + bool prefersFPUse(const MachineInstr &MI, const MachineRegisterInfo &MRI, + const WebAssemblyRegisterInfo &TRI, unsigned Depth = 0) const; }; } // end namespace llvm #endif diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyInstrMemory.td b/llvm/lib/Target/WebAssembly/WebAssemblyInstrMemory.td index 0cbe9d0c6a6a4..85fb6812769f4 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyInstrMemory.td +++ b/llvm/lib/Target/WebAssembly/WebAssemblyInstrMemory.td @@ -145,8 +145,8 @@ defm STORE_I64 : WebAssemblyStore; defm STORE_F32 : WebAssemblyStore; defm STORE_F64 : WebAssemblyStore; -multiclass StorePat { - def : Pat<(kind ty:$val, (AddrOps32 offset32_op:$offset, I32:$addr)), +multiclass StorePat { + def : Pat<(kind (ty rc:$val), (AddrOps32 offset32_op:$offset, I32:$addr)), (!cast(Name # "_A32") 0, offset32_op:$offset, I32:$addr, @@ -160,10 +160,10 @@ multiclass StorePat { Requires<[HasAddr64]>; } -defm : StorePat; -defm : StorePat; -defm : StorePat; -defm : StorePat; +defm : StorePat; +defm : StorePat; +defm : StorePat; +defm 
: StorePat; // Truncating store. defm STORE8_I32 : WebAssemblyStore; @@ -176,13 +176,13 @@ defm STORE32_I64 : WebAssemblyStore; defm STORE_F16_F32 : WebAssemblyStore; -defm : StorePat; -defm : StorePat; -defm : StorePat; -defm : StorePat; -defm : StorePat; +defm : StorePat; +defm : StorePat; +defm : StorePat; +defm : StorePat; +defm : StorePat; -defm : StorePat; +defm : StorePat; multiclass MemoryOps { // Current memory size. diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td b/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td index 14097d7b40a9c..0d3c5eab1adf6 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td +++ b/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td @@ -391,7 +391,7 @@ defm STORE_V128_A64 : // Def store patterns from WebAssemblyInstrMemory.td for vector types foreach vec = AllVecs in { -defm : StorePat; +defm : StorePat; } // Store lane From 42f62ef69e4630ef256a491666071c8b2c9394e8 Mon Sep 17 00:00:00 2001 From: Demetrius Kanios Date: Sat, 11 Oct 2025 12:41:21 -0700 Subject: [PATCH 17/17] Ensure GlobalISel is actually linked in --- llvm/lib/Target/WebAssembly/CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/llvm/lib/Target/WebAssembly/CMakeLists.txt b/llvm/lib/Target/WebAssembly/CMakeLists.txt index a295e2daac20f..217df5f987e85 100644 --- a/llvm/lib/Target/WebAssembly/CMakeLists.txt +++ b/llvm/lib/Target/WebAssembly/CMakeLists.txt @@ -89,6 +89,7 @@ add_llvm_target(WebAssemblyCodeGen CodeGen CodeGenTypes Core + GlobalISel MC Scalar SelectionDAG