diff --git a/llvm/docs/LangRef.rst b/llvm/docs/LangRef.rst index 33c23f2949765..7cd7e815fd8d7 100644 --- a/llvm/docs/LangRef.rst +++ b/llvm/docs/LangRef.rst @@ -3234,6 +3234,22 @@ A "convergencectrl" operand bundle is only valid on a ``convergent`` operation. When present, the operand bundle must contain exactly one value of token type. See the :doc:`ConvergentOperations` document for details. +Deactivation Symbol Operand Bundles +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +A ``"deactivation-symbol"`` operand bundle is valid on the following +instructions (AArch64 only): + +- Call to a normal function with ``notail`` attribute and a first argument and + return value of type ``ptr``. +- Call to ``llvm.ptrauth.sign`` or ``llvm.ptrauth.auth`` intrinsics. + +This operand bundle specifies that if the deactivation symbol is defined +to a valid value for the target, the marked instruction will return the +value of its first argument instead of calling the specified function +or intrinsic. This is achieved with ``PATCHINST`` relocations on the +target instructions (see the AArch64 psABI for details). + .. _moduleasm: Module-Level Inline Assembly diff --git a/llvm/include/llvm/CodeGen/GlobalISel/CallLowering.h b/llvm/include/llvm/CodeGen/GlobalISel/CallLowering.h index a8bde824527a5..fea900f37ec74 100644 --- a/llvm/include/llvm/CodeGen/GlobalISel/CallLowering.h +++ b/llvm/include/llvm/CodeGen/GlobalISel/CallLowering.h @@ -159,6 +159,8 @@ class LLVM_ABI CallLowering { /// True if this call results in convergent operations. bool IsConvergent = true; + + GlobalValue *DeactivationSymbol = nullptr; }; /// Argument handling is mostly uniform between the four places that diff --git a/llvm/include/llvm/CodeGen/GlobalISel/MachineIRBuilder.h b/llvm/include/llvm/CodeGen/GlobalISel/MachineIRBuilder.h index 40c7792f7e8a2..5f3f1d386569c 100644 --- a/llvm/include/llvm/CodeGen/GlobalISel/MachineIRBuilder.h +++ b/llvm/include/llvm/CodeGen/GlobalISel/MachineIRBuilder.h @@ -56,6 +56,7 @@ struct MachineIRBuilderState { MDNode *PCSections = nullptr; /// MMRA Metadata to be set on any instruction we create. MDNode *MMRA = nullptr; + Value *DS = nullptr; /// \name Fields describing the insertion point. /// @{ @@ -369,6 +370,7 @@ class LLVM_ABI MachineIRBuilder { State.II = MI.getIterator(); setPCSections(MI.getPCSections()); setMMRAMetadata(MI.getMMRAMetadata()); + setDeactivationSymbol(MI.getDeactivationSymbol()); } /// @} @@ -405,6 +407,9 @@ class LLVM_ABI MachineIRBuilder { /// Set the PC sections metadata to \p MD for all the next build instructions. void setMMRAMetadata(MDNode *MMRA) { State.MMRA = MMRA; } + Value *getDeactivationSymbol() { return State.DS; } + void setDeactivationSymbol(Value *DS) { State.DS = DS; } + /// Get the current instruction's MMRA metadata. MDNode *getMMRAMetadata() { return State.MMRA; } diff --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h index cdaa916548c25..b32f3dacbb3a4 100644 --- a/llvm/include/llvm/CodeGen/ISDOpcodes.h +++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h @@ -1579,6 +1579,10 @@ enum NodeType { // Outputs: Output Chain CLEAR_CACHE, + // Untyped node storing deactivation symbol reference + // (DeactivationSymbolSDNode). + DEACTIVATION_SYMBOL, + /// BUILTIN_OP_END - This must be the last enum value in this list. /// The target-specific pre-isel opcode values start here. BUILTIN_OP_END diff --git a/llvm/include/llvm/CodeGen/MachineFunction.h b/llvm/include/llvm/CodeGen/MachineFunction.h index ef783f276b7d4..08ffdb2cb469d 100644 --- a/llvm/include/llvm/CodeGen/MachineFunction.h +++ b/llvm/include/llvm/CodeGen/MachineFunction.h @@ -1207,7 +1207,7 @@ class LLVM_ABI MachineFunction { ArrayRef MMOs, MCSymbol *PreInstrSymbol = nullptr, MCSymbol *PostInstrSymbol = nullptr, MDNode *HeapAllocMarker = nullptr, MDNode *PCSections = nullptr, uint32_t CFIType = 0, - MDNode *MMRAs = nullptr); + MDNode *MMRAs = nullptr, Value *DS = nullptr); /// Allocate a string and populate it with the given external symbol name. const char *createExternalSymbolName(StringRef Name); diff --git a/llvm/include/llvm/CodeGen/MachineInstr.h b/llvm/include/llvm/CodeGen/MachineInstr.h index ca984459c365a..077e39b49df6f 100644 --- a/llvm/include/llvm/CodeGen/MachineInstr.h +++ b/llvm/include/llvm/CodeGen/MachineInstr.h @@ -160,8 +160,9 @@ class MachineInstr /// /// This has to be defined eagerly due to the implementation constraints of /// `PointerSumType` where it is used. - class ExtraInfo final : TrailingObjects { + class ExtraInfo final + : TrailingObjects { public: static ExtraInfo *create(BumpPtrAllocator &Allocator, ArrayRef MMOs, @@ -169,20 +170,23 @@ class MachineInstr MCSymbol *PostInstrSymbol = nullptr, MDNode *HeapAllocMarker = nullptr, MDNode *PCSections = nullptr, uint32_t CFIType = 0, - MDNode *MMRAs = nullptr) { + MDNode *MMRAs = nullptr, Value *DS = nullptr) { bool HasPreInstrSymbol = PreInstrSymbol != nullptr; bool HasPostInstrSymbol = PostInstrSymbol != nullptr; bool HasHeapAllocMarker = HeapAllocMarker != nullptr; bool HasMMRAs = MMRAs != nullptr; bool HasCFIType = CFIType != 0; bool HasPCSections = PCSections != nullptr; + bool HasDS = DS != nullptr; auto *Result = new (Allocator.Allocate( - totalSizeToAlloc( + totalSizeToAlloc( MMOs.size(), HasPreInstrSymbol + HasPostInstrSymbol, - HasHeapAllocMarker + HasPCSections + HasMMRAs, HasCFIType), + HasHeapAllocMarker + HasPCSections + HasMMRAs, HasCFIType, HasDS), alignof(ExtraInfo))) ExtraInfo(MMOs.size(), HasPreInstrSymbol, HasPostInstrSymbol, - HasHeapAllocMarker, HasPCSections, HasCFIType, HasMMRAs); + HasHeapAllocMarker, HasPCSections, HasCFIType, HasMMRAs, + HasDS); // Copy the actual data into the trailing objects. llvm::copy(MMOs, Result->getTrailingObjects()); @@ -202,6 +206,8 @@ class MachineInstr Result->getTrailingObjects()[0] = CFIType; if (HasMMRAs) Result->getTrailingObjects()[MDNodeIdx++] = MMRAs; + if (HasDS) + Result->getTrailingObjects()[0] = DS; return Result; } @@ -240,6 +246,10 @@ class MachineInstr : nullptr; } + Value *getDeactivationSymbol() const { + return HasDS ? getTrailingObjects()[0] : 0; + } + private: friend TrailingObjects; @@ -255,6 +265,7 @@ class MachineInstr const bool HasPCSections; const bool HasCFIType; const bool HasMMRAs; + const bool HasDS; // Implement the `TrailingObjects` internal API. size_t numTrailingObjects(OverloadToken) const { @@ -269,16 +280,17 @@ class MachineInstr size_t numTrailingObjects(OverloadToken) const { return HasCFIType; } + size_t numTrailingObjects(OverloadToken) const { return HasDS; } // Just a boring constructor to allow us to initialize the sizes. Always use // the `create` routine above. ExtraInfo(int NumMMOs, bool HasPreInstrSymbol, bool HasPostInstrSymbol, bool HasHeapAllocMarker, bool HasPCSections, bool HasCFIType, - bool HasMMRAs) + bool HasMMRAs, bool HasDS) : NumMMOs(NumMMOs), HasPreInstrSymbol(HasPreInstrSymbol), HasPostInstrSymbol(HasPostInstrSymbol), HasHeapAllocMarker(HasHeapAllocMarker), HasPCSections(HasPCSections), - HasCFIType(HasCFIType), HasMMRAs(HasMMRAs) {} + HasCFIType(HasCFIType), HasMMRAs(HasMMRAs), HasDS(HasDS) {} }; /// Enumeration of the kinds of inline extra info available. It is important @@ -867,6 +879,14 @@ class MachineInstr return nullptr; } + Value *getDeactivationSymbol() const { + if (!Info) + return nullptr; + if (ExtraInfo *EI = Info.get()) + return EI->getDeactivationSymbol(); + return nullptr; + } + /// Helper to extract a CFI type hash if one has been added. uint32_t getCFIType() const { if (!Info) @@ -1969,6 +1989,8 @@ class MachineInstr /// Set the CFI type for the instruction. LLVM_ABI void setCFIType(MachineFunction &MF, uint32_t Type); + LLVM_ABI void setDeactivationSymbol(MachineFunction &MF, Value *DS); + /// Return the MIFlags which represent both MachineInstrs. This /// should be used when merging two MachineInstrs into one. This routine does /// not modify the MIFlags of this MachineInstr. @@ -2088,7 +2110,7 @@ class MachineInstr void setExtraInfo(MachineFunction &MF, ArrayRef MMOs, MCSymbol *PreInstrSymbol, MCSymbol *PostInstrSymbol, MDNode *HeapAllocMarker, MDNode *PCSections, - uint32_t CFIType, MDNode *MMRAs); + uint32_t CFIType, MDNode *MMRAs, Value *DS); }; /// Special DenseMapInfo traits to compare MachineInstr* by *value* of the diff --git a/llvm/include/llvm/CodeGen/MachineInstrBuilder.h b/llvm/include/llvm/CodeGen/MachineInstrBuilder.h index e705d7d99544c..caeb430d6fd1c 100644 --- a/llvm/include/llvm/CodeGen/MachineInstrBuilder.h +++ b/llvm/include/llvm/CodeGen/MachineInstrBuilder.h @@ -70,29 +70,44 @@ enum { } // end namespace RegState /// Set of metadata that should be preserved when using BuildMI(). This provides -/// a more convenient way of preserving DebugLoc, PCSections and MMRA. +/// a more convenient way of preserving certain data from the original +/// instruction. class MIMetadata { public: MIMetadata() = default; - MIMetadata(DebugLoc DL, MDNode *PCSections = nullptr, MDNode *MMRA = nullptr) - : DL(std::move(DL)), PCSections(PCSections), MMRA(MMRA) {} + MIMetadata(DebugLoc DL, MDNode *PCSections = nullptr, MDNode *MMRA = nullptr, + Value *DeactivationSymbol = nullptr) + : DL(std::move(DL)), PCSections(PCSections), MMRA(MMRA), + DeactivationSymbol(DeactivationSymbol) {} MIMetadata(const DILocation *DI, MDNode *PCSections = nullptr, MDNode *MMRA = nullptr) : DL(DI), PCSections(PCSections), MMRA(MMRA) {} explicit MIMetadata(const Instruction &From) : DL(From.getDebugLoc()), - PCSections(From.getMetadata(LLVMContext::MD_pcsections)) {} + PCSections(From.getMetadata(LLVMContext::MD_pcsections)), + DeactivationSymbol(getDeactivationSymbol(&From)) {} explicit MIMetadata(const MachineInstr &From) - : DL(From.getDebugLoc()), PCSections(From.getPCSections()) {} + : DL(From.getDebugLoc()), PCSections(From.getPCSections()), + DeactivationSymbol(From.getDeactivationSymbol()) {} const DebugLoc &getDL() const { return DL; } MDNode *getPCSections() const { return PCSections; } MDNode *getMMRAMetadata() const { return MMRA; } + Value *getDeactivationSymbol() const { return DeactivationSymbol; } private: DebugLoc DL; MDNode *PCSections = nullptr; MDNode *MMRA = nullptr; + Value *DeactivationSymbol = nullptr; + + static inline Value *getDeactivationSymbol(const Instruction *I) { + if (auto *CB = dyn_cast(I)) + if (auto Bundle = + CB->getOperandBundle(llvm::LLVMContext::OB_deactivation_symbol)) + return Bundle->Inputs[0].get(); + return nullptr; + } }; class MachineInstrBuilder { @@ -348,6 +363,8 @@ class MachineInstrBuilder { MI->setPCSections(*MF, MIMD.getPCSections()); if (MIMD.getMMRAMetadata()) MI->setMMRAMetadata(*MF, MIMD.getMMRAMetadata()); + if (MIMD.getDeactivationSymbol()) + MI->setDeactivationSymbol(*MF, MIMD.getDeactivationSymbol()); return *this; } diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h index b024e8a68bd6e..501cbc947132e 100644 --- a/llvm/include/llvm/CodeGen/SelectionDAG.h +++ b/llvm/include/llvm/CodeGen/SelectionDAG.h @@ -759,6 +759,7 @@ class SelectionDAG { int64_t offset = 0, unsigned TargetFlags = 0) { return getGlobalAddress(GV, DL, VT, offset, true, TargetFlags); } + LLVM_ABI SDValue getDeactivationSymbol(const GlobalValue *GV); LLVM_ABI SDValue getFrameIndex(int FI, EVT VT, bool isTarget = false); SDValue getTargetFrameIndex(int FI, EVT VT) { return getFrameIndex(FI, VT, true); diff --git a/llvm/include/llvm/CodeGen/SelectionDAGISel.h b/llvm/include/llvm/CodeGen/SelectionDAGISel.h index c5cdf76f4777e..7add717227963 100644 --- a/llvm/include/llvm/CodeGen/SelectionDAGISel.h +++ b/llvm/include/llvm/CodeGen/SelectionDAGISel.h @@ -151,6 +151,7 @@ class SelectionDAGISel { OPC_RecordChild7, OPC_RecordMemRef, OPC_CaptureGlueInput, + OPC_CaptureDeactivationSymbol, OPC_MoveChild, OPC_MoveChild0, OPC_MoveChild1, diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h index cfc8a4243e894..aa72e81b2ab54 100644 --- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h +++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h @@ -2005,6 +2005,22 @@ class GlobalAddressSDNode : public SDNode { } }; +class DeactivationSymbolSDNode : public SDNode { + friend class SelectionDAG; + + const GlobalValue *TheGlobal; + + DeactivationSymbolSDNode(const GlobalValue *GV, SDVTList VTs) + : SDNode(ISD::DEACTIVATION_SYMBOL, 0, DebugLoc(), VTs), TheGlobal(GV) {} + +public: + const GlobalValue *getGlobal() const { return TheGlobal; } + + static bool classof(const SDNode *N) { + return N->getOpcode() == ISD::DEACTIVATION_SYMBOL; + } +}; + class FrameIndexSDNode : public SDNode { friend class SelectionDAG; diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h index 7df5d8a09f0f6..b2697c81fd825 100644 --- a/llvm/include/llvm/CodeGen/TargetLowering.h +++ b/llvm/include/llvm/CodeGen/TargetLowering.h @@ -4765,6 +4765,7 @@ class LLVM_ABI TargetLowering : public TargetLoweringBase { SmallVector InVals; const ConstantInt *CFIType = nullptr; SDValue ConvergenceControlToken; + GlobalValue *DeactivationSymbol = nullptr; std::optional PAI; @@ -4918,6 +4919,11 @@ class LLVM_ABI TargetLowering : public TargetLoweringBase { return *this; } + CallLoweringInfo &setDeactivationSymbol(GlobalValue *Sym) { + DeactivationSymbol = Sym; + return *this; + } + ArgListTy &getArgs() { return Args; } diff --git a/llvm/include/llvm/IR/LLVMContext.h b/llvm/include/llvm/IR/LLVMContext.h index 5972dcb637dfa..d938f4609742b 100644 --- a/llvm/include/llvm/IR/LLVMContext.h +++ b/llvm/include/llvm/IR/LLVMContext.h @@ -98,7 +98,8 @@ class LLVMContext { OB_kcfi = 8, // "kcfi" OB_convergencectrl = 9, // "convergencectrl" OB_align = 10, // "align" - OB_LastBundleID = OB_align // Marker for last bundle ID + OB_deactivation_symbol = 11, // "deactivation-symbol" + OB_LastBundleID = OB_deactivation_symbol }; /// getMDKindID - Return a unique non-zero ID for the specified metadata kind. diff --git a/llvm/include/llvm/Target/Target.td b/llvm/include/llvm/Target/Target.td index 96a7d7c2091d2..54162dc6bb30f 100644 --- a/llvm/include/llvm/Target/Target.td +++ b/llvm/include/llvm/Target/Target.td @@ -694,6 +694,7 @@ class Instruction : InstructionEncoding { // If so, make sure to override // TargetInstrInfo::getInsertSubregLikeInputs. bit variadicOpsAreDefs = false; // Are variadic operands definitions? + bit supportsDeactivationSymbol = false; // Does the instruction have side effects that are not captured by any // operands of the instruction or other flags? diff --git a/llvm/lib/CodeGen/GlobalISel/CallLowering.cpp b/llvm/lib/CodeGen/GlobalISel/CallLowering.cpp index 7be7468300569..e2ed45eec0ecd 100644 --- a/llvm/lib/CodeGen/GlobalISel/CallLowering.cpp +++ b/llvm/lib/CodeGen/GlobalISel/CallLowering.cpp @@ -196,6 +196,10 @@ bool CallLowering::lowerCall(MachineIRBuilder &MIRBuilder, const CallBase &CB, assert(Info.CFIType->getType()->isIntegerTy(32) && "Invalid CFI type"); } + if (auto Bundle = CB.getOperandBundle(LLVMContext::OB_deactivation_symbol)) { + Info.DeactivationSymbol = cast(Bundle->Inputs[0]); + } + Info.CB = &CB; Info.KnownCallees = CB.getMetadata(LLVMContext::MD_callees); Info.CallConv = CallConv; diff --git a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp index 2ec138b6e186d..e0665d99a891d 100644 --- a/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp +++ b/llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp @@ -2917,6 +2917,9 @@ bool IRTranslator::translateIntrinsic( } } + if (auto Bundle = CB.getOperandBundle(LLVMContext::OB_deactivation_symbol)) + MIB->setDeactivationSymbol(*MF, Bundle->Inputs[0].get()); + return true; } diff --git a/llvm/lib/CodeGen/GlobalISel/MachineIRBuilder.cpp b/llvm/lib/CodeGen/GlobalISel/MachineIRBuilder.cpp index 637acd61c8a5f..3906b311addf0 100644 --- a/llvm/lib/CodeGen/GlobalISel/MachineIRBuilder.cpp +++ b/llvm/lib/CodeGen/GlobalISel/MachineIRBuilder.cpp @@ -38,8 +38,10 @@ void MachineIRBuilder::setMF(MachineFunction &MF) { //------------------------------------------------------------------------------ MachineInstrBuilder MachineIRBuilder::buildInstrNoInsert(unsigned Opcode) { - return BuildMI(getMF(), {getDL(), getPCSections(), getMMRAMetadata()}, - getTII().get(Opcode)); + return BuildMI( + getMF(), + {getDL(), getPCSections(), getMMRAMetadata(), getDeactivationSymbol()}, + getTII().get(Opcode)); } MachineInstrBuilder MachineIRBuilder::insertInstr(MachineInstrBuilder MIB) { diff --git a/llvm/lib/CodeGen/MIRParser/MILexer.cpp b/llvm/lib/CodeGen/MIRParser/MILexer.cpp index 8b72c295416a2..dbd56c7414f38 100644 --- a/llvm/lib/CodeGen/MIRParser/MILexer.cpp +++ b/llvm/lib/CodeGen/MIRParser/MILexer.cpp @@ -281,6 +281,7 @@ static MIToken::TokenKind getIdentifierKind(StringRef Identifier) { .Case("heap-alloc-marker", MIToken::kw_heap_alloc_marker) .Case("pcsections", MIToken::kw_pcsections) .Case("cfi-type", MIToken::kw_cfi_type) + .Case("deactivation-symbol", MIToken::kw_deactivation_symbol) .Case("bbsections", MIToken::kw_bbsections) .Case("bb_id", MIToken::kw_bb_id) .Case("unknown-size", MIToken::kw_unknown_size) diff --git a/llvm/lib/CodeGen/MIRParser/MILexer.h b/llvm/lib/CodeGen/MIRParser/MILexer.h index 0627f176b9e00..0407a0e7540d7 100644 --- a/llvm/lib/CodeGen/MIRParser/MILexer.h +++ b/llvm/lib/CodeGen/MIRParser/MILexer.h @@ -136,6 +136,7 @@ struct MIToken { kw_heap_alloc_marker, kw_pcsections, kw_cfi_type, + kw_deactivation_symbol, kw_bbsections, kw_bb_id, kw_unknown_size, diff --git a/llvm/lib/CodeGen/MIRParser/MIParser.cpp b/llvm/lib/CodeGen/MIRParser/MIParser.cpp index 434a579c3be3f..f35274d4e2edf 100644 --- a/llvm/lib/CodeGen/MIRParser/MIParser.cpp +++ b/llvm/lib/CodeGen/MIRParser/MIParser.cpp @@ -1072,6 +1072,7 @@ bool MIParser::parse(MachineInstr *&MI) { Token.isNot(MIToken::kw_heap_alloc_marker) && Token.isNot(MIToken::kw_pcsections) && Token.isNot(MIToken::kw_cfi_type) && + Token.isNot(MIToken::kw_deactivation_symbol) && Token.isNot(MIToken::kw_debug_location) && Token.isNot(MIToken::kw_debug_instr_number) && Token.isNot(MIToken::coloncolon) && Token.isNot(MIToken::lbrace)) { @@ -1120,6 +1121,14 @@ bool MIParser::parse(MachineInstr *&MI) { lex(); } + GlobalValue *DS = nullptr; + if (Token.is(MIToken::kw_deactivation_symbol)) { + lex(); + if (parseGlobalValue(DS)) + return true; + lex(); + } + unsigned InstrNum = 0; if (Token.is(MIToken::kw_debug_instr_number)) { lex(); @@ -1196,6 +1205,8 @@ bool MIParser::parse(MachineInstr *&MI) { MI->setPCSections(MF, PCSections); if (CFIType) MI->setCFIType(MF, CFIType); + if (DS) + MI->setDeactivationSymbol(MF, DS); if (!MemOperands.empty()) MI->setMemRefs(MF, MemOperands); if (InstrNum) diff --git a/llvm/lib/CodeGen/MIRPrinter.cpp b/llvm/lib/CodeGen/MIRPrinter.cpp index 1d54d72336860..c0554497653f8 100644 --- a/llvm/lib/CodeGen/MIRPrinter.cpp +++ b/llvm/lib/CodeGen/MIRPrinter.cpp @@ -19,6 +19,7 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" +#include "llvm/CodeGen/MIRFormatter.h" #include "llvm/CodeGen/MIRYamlMapping.h" #include "llvm/CodeGen/MachineBasicBlock.h" #include "llvm/CodeGen/MachineConstantPool.h" @@ -895,6 +896,10 @@ static void printMI(raw_ostream &OS, MFPrintState &State, } if (uint32_t CFIType = MI.getCFIType()) OS << LS << "cfi-type " << CFIType; + if (Value *DS = MI.getDeactivationSymbol()) { + OS << LS << "deactivation-symbol "; + MIRFormatter::printIRValue(OS, *DS, State.MST); + } if (auto Num = MI.peekDebugInstrNum()) OS << LS << "debug-instr-number " << Num; diff --git a/llvm/lib/CodeGen/MachineFunction.cpp b/llvm/lib/CodeGen/MachineFunction.cpp index bfa5ab274c686..634547ded992f 100644 --- a/llvm/lib/CodeGen/MachineFunction.cpp +++ b/llvm/lib/CodeGen/MachineFunction.cpp @@ -609,10 +609,10 @@ MachineFunction::getMachineMemOperand(const MachineMemOperand *MMO, MachineInstr::ExtraInfo *MachineFunction::createMIExtraInfo( ArrayRef MMOs, MCSymbol *PreInstrSymbol, MCSymbol *PostInstrSymbol, MDNode *HeapAllocMarker, MDNode *PCSections, - uint32_t CFIType, MDNode *MMRAs) { + uint32_t CFIType, MDNode *MMRAs, Value *DS) { return MachineInstr::ExtraInfo::create(Allocator, MMOs, PreInstrSymbol, PostInstrSymbol, HeapAllocMarker, - PCSections, CFIType, MMRAs); + PCSections, CFIType, MMRAs, DS); } const char *MachineFunction::createExternalSymbolName(StringRef Name) { diff --git a/llvm/lib/CodeGen/MachineInstr.cpp b/llvm/lib/CodeGen/MachineInstr.cpp index eb46124d9eb5f..18111156efa4f 100644 --- a/llvm/lib/CodeGen/MachineInstr.cpp +++ b/llvm/lib/CodeGen/MachineInstr.cpp @@ -322,15 +322,17 @@ void MachineInstr::setExtraInfo(MachineFunction &MF, MCSymbol *PreInstrSymbol, MCSymbol *PostInstrSymbol, MDNode *HeapAllocMarker, MDNode *PCSections, - uint32_t CFIType, MDNode *MMRAs) { + uint32_t CFIType, MDNode *MMRAs, Value *DS) { bool HasPreInstrSymbol = PreInstrSymbol != nullptr; bool HasPostInstrSymbol = PostInstrSymbol != nullptr; bool HasHeapAllocMarker = HeapAllocMarker != nullptr; bool HasPCSections = PCSections != nullptr; bool HasCFIType = CFIType != 0; bool HasMMRAs = MMRAs != nullptr; + bool HasDS = DS != nullptr; int NumPointers = MMOs.size() + HasPreInstrSymbol + HasPostInstrSymbol + - HasHeapAllocMarker + HasPCSections + HasCFIType + HasMMRAs; + HasHeapAllocMarker + HasPCSections + HasCFIType + HasMMRAs + + HasDS; // Drop all extra info if there is none. if (NumPointers <= 0) { @@ -343,10 +345,10 @@ void MachineInstr::setExtraInfo(MachineFunction &MF, // 32-bit pointers. // FIXME: Maybe we should make the symbols in the extra info mutable? else if (NumPointers > 1 || HasMMRAs || HasHeapAllocMarker || HasPCSections || - HasCFIType) { + HasCFIType || HasDS) { Info.set( MF.createMIExtraInfo(MMOs, PreInstrSymbol, PostInstrSymbol, - HeapAllocMarker, PCSections, CFIType, MMRAs)); + HeapAllocMarker, PCSections, CFIType, MMRAs, DS)); return; } @@ -365,7 +367,7 @@ void MachineInstr::dropMemRefs(MachineFunction &MF) { setExtraInfo(MF, {}, getPreInstrSymbol(), getPostInstrSymbol(), getHeapAllocMarker(), getPCSections(), getCFIType(), - getMMRAMetadata()); + getMMRAMetadata(), getDeactivationSymbol()); } void MachineInstr::setMemRefs(MachineFunction &MF, @@ -377,7 +379,7 @@ void MachineInstr::setMemRefs(MachineFunction &MF, setExtraInfo(MF, MMOs, getPreInstrSymbol(), getPostInstrSymbol(), getHeapAllocMarker(), getPCSections(), getCFIType(), - getMMRAMetadata()); + getMMRAMetadata(), getDeactivationSymbol()); } void MachineInstr::addMemOperand(MachineFunction &MF, @@ -488,7 +490,7 @@ void MachineInstr::setPreInstrSymbol(MachineFunction &MF, MCSymbol *Symbol) { setExtraInfo(MF, memoperands(), Symbol, getPostInstrSymbol(), getHeapAllocMarker(), getPCSections(), getCFIType(), - getMMRAMetadata()); + getMMRAMetadata(), getDeactivationSymbol()); } void MachineInstr::setPostInstrSymbol(MachineFunction &MF, MCSymbol *Symbol) { @@ -504,7 +506,7 @@ void MachineInstr::setPostInstrSymbol(MachineFunction &MF, MCSymbol *Symbol) { setExtraInfo(MF, memoperands(), getPreInstrSymbol(), Symbol, getHeapAllocMarker(), getPCSections(), getCFIType(), - getMMRAMetadata()); + getMMRAMetadata(), getDeactivationSymbol()); } void MachineInstr::setHeapAllocMarker(MachineFunction &MF, MDNode *Marker) { @@ -513,7 +515,8 @@ void MachineInstr::setHeapAllocMarker(MachineFunction &MF, MDNode *Marker) { return; setExtraInfo(MF, memoperands(), getPreInstrSymbol(), getPostInstrSymbol(), - Marker, getPCSections(), getCFIType(), getMMRAMetadata()); + Marker, getPCSections(), getCFIType(), getMMRAMetadata(), + getDeactivationSymbol()); } void MachineInstr::setPCSections(MachineFunction &MF, MDNode *PCSections) { @@ -523,7 +526,7 @@ void MachineInstr::setPCSections(MachineFunction &MF, MDNode *PCSections) { setExtraInfo(MF, memoperands(), getPreInstrSymbol(), getPostInstrSymbol(), getHeapAllocMarker(), PCSections, getCFIType(), - getMMRAMetadata()); + getMMRAMetadata(), getDeactivationSymbol()); } void MachineInstr::setCFIType(MachineFunction &MF, uint32_t Type) { @@ -532,7 +535,8 @@ void MachineInstr::setCFIType(MachineFunction &MF, uint32_t Type) { return; setExtraInfo(MF, memoperands(), getPreInstrSymbol(), getPostInstrSymbol(), - getHeapAllocMarker(), getPCSections(), Type, getMMRAMetadata()); + getHeapAllocMarker(), getPCSections(), Type, getMMRAMetadata(), + getDeactivationSymbol()); } void MachineInstr::setMMRAMetadata(MachineFunction &MF, MDNode *MMRAs) { @@ -541,7 +545,18 @@ void MachineInstr::setMMRAMetadata(MachineFunction &MF, MDNode *MMRAs) { return; setExtraInfo(MF, memoperands(), getPreInstrSymbol(), getPostInstrSymbol(), - getHeapAllocMarker(), getPCSections(), getCFIType(), MMRAs); + getHeapAllocMarker(), getPCSections(), getCFIType(), MMRAs, + getDeactivationSymbol()); +} + +void MachineInstr::setDeactivationSymbol(MachineFunction &MF, Value *DS) { + // Do nothing if old and new symbols are the same. + if (DS == getDeactivationSymbol()) + return; + + setExtraInfo(MF, memoperands(), getPreInstrSymbol(), getPostInstrSymbol(), + getHeapAllocMarker(), getPCSections(), getCFIType(), + getMMRAMetadata(), DS); } void MachineInstr::cloneInstrSymbols(MachineFunction &MF, @@ -730,6 +745,8 @@ bool MachineInstr::isIdenticalTo(const MachineInstr &Other, // Call instructions with different CFI types are not identical. if (isCall() && getCFIType() != Other.getCFIType()) return false; + if (getDeactivationSymbol() != Other.getDeactivationSymbol()) + return false; return true; } @@ -2037,6 +2054,8 @@ void MachineInstr::print(raw_ostream &OS, ModuleSlotTracker &MST, OS << ','; OS << " cfi-type " << CFIType; } + if (getDeactivationSymbol()) + OS << ", deactivation-symbol " << getDeactivationSymbol()->getName(); if (DebugInstrNum) { if (!FirstOp) diff --git a/llvm/lib/CodeGen/SelectionDAG/InstrEmitter.cpp b/llvm/lib/CodeGen/SelectionDAG/InstrEmitter.cpp index 52e8449fe510c..4ad721bf21959 100644 --- a/llvm/lib/CodeGen/SelectionDAG/InstrEmitter.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/InstrEmitter.cpp @@ -15,10 +15,12 @@ #include "InstrEmitter.h" #include "SDNodeDbgValue.h" #include "llvm/BinaryFormat/Dwarf.h" +#include "llvm/CodeGen/ISDOpcodes.h" #include "llvm/CodeGen/MachineConstantPool.h" #include "llvm/CodeGen/MachineFunction.h" #include "llvm/CodeGen/MachineInstrBuilder.h" #include "llvm/CodeGen/MachineRegisterInfo.h" +#include "llvm/CodeGen/SelectionDAGNodes.h" #include "llvm/CodeGen/StackMaps.h" #include "llvm/CodeGen/TargetInstrInfo.h" #include "llvm/CodeGen/TargetLowering.h" @@ -61,6 +63,8 @@ static unsigned countOperands(SDNode *Node, unsigned NumExpUses, unsigned N = Node->getNumOperands(); while (N && Node->getOperand(N - 1).getValueType() == MVT::Glue) --N; + if (N && Node->getOperand(N - 1).getOpcode() == ISD::DEACTIVATION_SYMBOL) + --N; // Ignore deactivation symbol if it exists. if (N && Node->getOperand(N - 1).getValueType() == MVT::Other) --N; // Ignore chain if it exists. @@ -1222,15 +1226,23 @@ EmitMachineNode(SDNode *Node, bool IsClone, bool IsCloned, } } - if (SDNode *GluedNode = Node->getGluedNode()) { - // FIXME: Possibly iterate over multiple glue nodes? - if (GluedNode->getOpcode() == - ~(unsigned)TargetOpcode::CONVERGENCECTRL_GLUE) { - Register VReg = getVR(GluedNode->getOperand(0), VRBaseMap); - MachineOperand MO = MachineOperand::CreateReg(VReg, /*isDef=*/false, - /*isImp=*/true); - MIB->addOperand(MO); - } + unsigned Op = Node->getNumOperands(); + if (Op != 0 && Node->getOperand(Op - 1)->getOpcode() == + ~(unsigned)TargetOpcode::CONVERGENCECTRL_GLUE) { + Register VReg = getVR(Node->getOperand(Op - 1)->getOperand(0), VRBaseMap); + MachineOperand MO = MachineOperand::CreateReg(VReg, /*isDef=*/false, + /*isImp=*/true); + MIB->addOperand(MO); + Op--; + } + + if (Op != 0 && + Node->getOperand(Op - 1)->getOpcode() == ISD::DEACTIVATION_SYMBOL) { + MI->setDeactivationSymbol( + *MF, const_cast( + cast(Node->getOperand(Op - 1)) + ->getGlobal())); + Op--; } // Run post-isel target hook to adjust this instruction if needed. @@ -1251,7 +1263,8 @@ EmitSpecialNode(SDNode *Node, bool IsClone, bool IsCloned, llvm_unreachable("This target-independent node should have been selected!"); case ISD::EntryToken: case ISD::MERGE_VALUES: - case ISD::TokenFactor: // fall thru + case ISD::TokenFactor: + case ISD::DEACTIVATION_SYMBOL: break; case ISD::CopyToReg: { Register DestReg = cast(Node->getOperand(1))->getReg(); diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp index 1b15a207a2d37..06735708d5369 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -1916,6 +1916,21 @@ SDValue SelectionDAG::getGlobalAddress(const GlobalValue *GV, const SDLoc &DL, return SDValue(N, 0); } +SDValue SelectionDAG::getDeactivationSymbol(const GlobalValue *GV) { + SDVTList VTs = getVTList(MVT::Untyped); + FoldingSetNodeID ID; + AddNodeIDNode(ID, ISD::DEACTIVATION_SYMBOL, VTs, {}); + ID.AddPointer(GV); + void *IP = nullptr; + if (SDNode *E = FindNodeOrInsertPos(ID, SDLoc(), IP)) + return SDValue(E, 0); + + auto *N = newSDNode(GV, VTs); + CSEMap.InsertNode(N, IP); + InsertNode(N); + return SDValue(N, 0); +} + SDValue SelectionDAG::getFrameIndex(int FI, EVT VT, bool isTarget) { unsigned Opc = isTarget ? ISD::TargetFrameIndex : ISD::FrameIndex; SDVTList VTs = getVTList(VT); diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp index 88b35582a9f7d..53d73ad618bd1 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp @@ -45,6 +45,7 @@ #include "llvm/CodeGen/MachineOperand.h" #include "llvm/CodeGen/MachineRegisterInfo.h" #include "llvm/CodeGen/SelectionDAG.h" +#include "llvm/CodeGen/SelectionDAGNodes.h" #include "llvm/CodeGen/SelectionDAGTargetInfo.h" #include "llvm/CodeGen/StackMaps.h" #include "llvm/CodeGen/SwiftErrorValueTracking.h" @@ -5376,6 +5377,14 @@ SmallVector SelectionDAGBuilder::getTargetIntrinsicOperands( } } + if (std::optional Bundle = + I.getOperandBundle(LLVMContext::OB_deactivation_symbol)) { + auto *Sym = Bundle->Inputs[0].get(); + SDValue SDSym = getValue(Sym); + SDSym = DAG.getDeactivationSymbol(cast(Sym)); + Ops.push_back(SDSym); + } + if (std::optional Bundle = I.getOperandBundle(LLVMContext::OB_convergencectrl)) { Value *Token = Bundle->Inputs[0].get(); @@ -9116,6 +9125,11 @@ void SelectionDAGBuilder::LowerCallTo(const CallBase &CB, SDValue Callee, ConvControlToken = getValue(Token); } + GlobalValue *DeactivationSymbol = nullptr; + if (auto Bundle = CB.getOperandBundle(LLVMContext::OB_deactivation_symbol)) { + DeactivationSymbol = cast(Bundle->Inputs[0].get()); + } + TargetLowering::CallLoweringInfo CLI(DAG); CLI.setDebugLoc(getCurSDLoc()) .setChain(getRoot()) @@ -9125,7 +9139,8 @@ void SelectionDAGBuilder::LowerCallTo(const CallBase &CB, SDValue Callee, .setIsPreallocated( CB.countOperandBundlesOfType(LLVMContext::OB_preallocated) != 0) .setCFIType(CFIType) - .setConvergenceControlToken(ConvControlToken); + .setConvergenceControlToken(ConvControlToken) + .setDeactivationSymbol(DeactivationSymbol); // Set the pointer authentication info if we have it. if (PAI) { @@ -9745,7 +9760,7 @@ void SelectionDAGBuilder::visitCall(const CallInst &I) { {LLVMContext::OB_deopt, LLVMContext::OB_funclet, LLVMContext::OB_cfguardtarget, LLVMContext::OB_preallocated, LLVMContext::OB_clang_arc_attachedcall, LLVMContext::OB_kcfi, - LLVMContext::OB_convergencectrl}); + LLVMContext::OB_convergencectrl, LLVMContext::OB_deactivation_symbol}); SDValue Callee = getValue(I.getCalledOperand()); diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp index 0fad4722b1871..dd8f18d3b8a6a 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp @@ -3308,6 +3308,7 @@ void SelectionDAGISel::SelectCodeCommon(SDNode *NodeToMatch, case ISD::LIFETIME_START: case ISD::LIFETIME_END: case ISD::PSEUDO_PROBE: + case ISD::DEACTIVATION_SYMBOL: NodeToMatch->setNodeId(-1); // Mark selected. return; case ISD::AssertSext: @@ -3389,7 +3390,7 @@ void SelectionDAGISel::SelectCodeCommon(SDNode *NodeToMatch, // These are the current input chain and glue for use when generating nodes. // Various Emit operations change these. For example, emitting a copytoreg // uses and updates these. - SDValue InputChain, InputGlue; + SDValue InputChain, InputGlue, DeactivationSymbol; // ChainNodesMatched - If a pattern matches nodes that have input/output // chains, the OPC_EmitMergeInputChains operation is emitted which indicates @@ -3542,6 +3543,15 @@ void SelectionDAGISel::SelectCodeCommon(SDNode *NodeToMatch, InputGlue = N->getOperand(N->getNumOperands()-1); continue; + case OPC_CaptureDeactivationSymbol: + // If the current node has a deactivation symbol, capture it in + // DeactivationSymbol. + if (N->getNumOperands() != 0 && + N->getOperand(N->getNumOperands() - 1).getOpcode() == + ISD::DEACTIVATION_SYMBOL) + DeactivationSymbol = N->getOperand(N->getNumOperands() - 1); + continue; + case OPC_MoveChild: { unsigned ChildNo = MatcherTable[MatcherIndex++]; if (ChildNo >= N.getNumOperands()) @@ -4223,6 +4233,8 @@ void SelectionDAGISel::SelectCodeCommon(SDNode *NodeToMatch, // If this has chain/glue inputs, add them. if (EmitNodeInfo & OPFL_Chain) Ops.push_back(InputChain); + if (DeactivationSymbol.getNode() != nullptr) + Ops.push_back(DeactivationSymbol); if ((EmitNodeInfo & OPFL_GlueInput) && InputGlue.getNode() != nullptr) Ops.push_back(InputGlue); diff --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp index cd39970f5111f..85d3690dd8306 100644 --- a/llvm/lib/IR/Instructions.cpp +++ b/llvm/lib/IR/Instructions.cpp @@ -620,7 +620,8 @@ bool CallBase::hasReadingOperandBundles() const { // ptrauth) forces a callsite to be at least readonly. return hasOperandBundlesOtherThan({LLVMContext::OB_ptrauth, LLVMContext::OB_kcfi, - LLVMContext::OB_convergencectrl}) && + LLVMContext::OB_convergencectrl, + LLVMContext::OB_deactivation_symbol}) && getIntrinsicID() != Intrinsic::assume; } @@ -628,7 +629,8 @@ bool CallBase::hasClobberingOperandBundles() const { return hasOperandBundlesOtherThan( {LLVMContext::OB_deopt, LLVMContext::OB_funclet, LLVMContext::OB_ptrauth, LLVMContext::OB_kcfi, - LLVMContext::OB_convergencectrl}) && + LLVMContext::OB_convergencectrl, + LLVMContext::OB_deactivation_symbol}) && getIntrinsicID() != Intrinsic::assume; } diff --git a/llvm/lib/IR/LLVMContext.cpp b/llvm/lib/IR/LLVMContext.cpp index 335c210c10e1a..10aba759185a7 100644 --- a/llvm/lib/IR/LLVMContext.cpp +++ b/llvm/lib/IR/LLVMContext.cpp @@ -55,6 +55,8 @@ static StringRef knownBundleName(unsigned BundleTagID) { return "convergencectrl"; case LLVMContext::OB_align: return "align"; + case LLVMContext::OB_deactivation_symbol: + return "deactivation-symbol"; default: llvm_unreachable("unknown bundle id"); } diff --git a/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp b/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp index 3aa77bd47930f..0543cdc2e63d4 100644 --- a/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp +++ b/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp @@ -176,7 +176,12 @@ class AArch64AsmPrinter : public AsmPrinter { const MachineOperand *AUTAddrDisc, Register Scratch, std::optional PACKey, - uint64_t PACDisc, Register PACAddrDisc); + uint64_t PACDisc, Register PACAddrDisc, Value *DS); + + // Emit R_AARCH64_PATCHINST, the deactivation symbol relocation. Returns true + // if no instruction should be emitted because the deactivation symbol is + // defined in the current module so this function emitted a NOP instead. + bool emitDeactivationSymbolRelocation(Value *DS); // Emit the sequence for PAC. void emitPtrauthSign(const MachineInstr *MI); @@ -2113,11 +2118,31 @@ void AArch64AsmPrinter::emitPtrauthTailCallHardening(const MachineInstr *TC) { LRCheckMethod); } +bool AArch64AsmPrinter::emitDeactivationSymbolRelocation(Value *DS) { + if (!DS) + return false; + + if (isa(DS)) { + // Just emit the nop directly. + EmitToStreamer(MCInstBuilder(AArch64::HINT).addImm(0)); + return true; + } + MCSymbol *Dot = OutContext.createTempSymbol(); + OutStreamer->emitLabel(Dot); + const MCExpr *DeactDotExpr = MCSymbolRefExpr::create(Dot, OutContext); + + const MCExpr *DSExpr = MCSymbolRefExpr::create( + OutContext.getOrCreateSymbol(DS->getName()), OutContext); + OutStreamer->emitRelocDirective(*DeactDotExpr, "R_AARCH64_PATCHINST", DSExpr, + SMLoc()); + return false; +} + void AArch64AsmPrinter::emitPtrauthAuthResign( Register AUTVal, AArch64PACKey::ID AUTKey, uint64_t AUTDisc, const MachineOperand *AUTAddrDisc, Register Scratch, std::optional PACKey, uint64_t PACDisc, - Register PACAddrDisc) { + Register PACAddrDisc, Value *DS) { const bool IsAUTPAC = PACKey.has_value(); // We expand AUT/AUTPAC into a sequence of the form @@ -2164,15 +2189,17 @@ void AArch64AsmPrinter::emitPtrauthAuthResign( bool AUTZero = AUTDiscReg == AArch64::XZR; unsigned AUTOpc = getAUTOpcodeForKey(AUTKey, AUTZero); - // autiza x16 ; if AUTZero - // autia x16, x17 ; if !AUTZero - MCInst AUTInst; - AUTInst.setOpcode(AUTOpc); - AUTInst.addOperand(MCOperand::createReg(AUTVal)); - AUTInst.addOperand(MCOperand::createReg(AUTVal)); - if (!AUTZero) - AUTInst.addOperand(MCOperand::createReg(AUTDiscReg)); - EmitToStreamer(*OutStreamer, AUTInst); + if (!emitDeactivationSymbolRelocation(DS)) { + // autiza x16 ; if AUTZero + // autia x16, x17 ; if !AUTZero + MCInst AUTInst; + AUTInst.setOpcode(AUTOpc); + AUTInst.addOperand(MCOperand::createReg(AUTVal)); + AUTInst.addOperand(MCOperand::createReg(AUTVal)); + if (!AUTZero) + AUTInst.addOperand(MCOperand::createReg(AUTDiscReg)); + EmitToStreamer(*OutStreamer, AUTInst); + } // Unchecked or checked-but-non-trapping AUT is just an "AUT": we're done. if (!IsAUTPAC && (!ShouldCheck || !ShouldTrap)) @@ -2236,6 +2263,9 @@ void AArch64AsmPrinter::emitPtrauthSign(const MachineInstr *MI) { bool IsZeroDisc = DiscReg == AArch64::XZR; unsigned Opc = getPACOpcodeForKey(Key, IsZeroDisc); + if (emitDeactivationSymbolRelocation(MI->getDeactivationSymbol())) + return; + // paciza x16 ; if IsZeroDisc // pacia x16, x17 ; if !IsZeroDisc MCInst PACInst; @@ -3136,17 +3166,18 @@ void AArch64AsmPrinter::emitInstruction(const MachineInstr *MI) { } case AArch64::AUTx16x17: - emitPtrauthAuthResign(AArch64::X16, - (AArch64PACKey::ID)MI->getOperand(0).getImm(), - MI->getOperand(1).getImm(), &MI->getOperand(2), - AArch64::X17, std::nullopt, 0, 0); + emitPtrauthAuthResign( + AArch64::X16, (AArch64PACKey::ID)MI->getOperand(0).getImm(), + MI->getOperand(1).getImm(), &MI->getOperand(2), AArch64::X17, + std::nullopt, 0, 0, MI->getDeactivationSymbol()); return; case AArch64::AUTxMxN: emitPtrauthAuthResign(MI->getOperand(0).getReg(), (AArch64PACKey::ID)MI->getOperand(3).getImm(), MI->getOperand(4).getImm(), &MI->getOperand(5), - MI->getOperand(1).getReg(), std::nullopt, 0, 0); + MI->getOperand(1).getReg(), std::nullopt, 0, 0, + MI->getDeactivationSymbol()); return; case AArch64::AUTPAC: @@ -3154,7 +3185,8 @@ void AArch64AsmPrinter::emitInstruction(const MachineInstr *MI) { AArch64::X16, (AArch64PACKey::ID)MI->getOperand(0).getImm(), MI->getOperand(1).getImm(), &MI->getOperand(2), AArch64::X17, (AArch64PACKey::ID)MI->getOperand(3).getImm(), - MI->getOperand(4).getImm(), MI->getOperand(5).getReg()); + MI->getOperand(4).getImm(), MI->getOperand(5).getReg(), + MI->getDeactivationSymbol()); return; case AArch64::PAC: @@ -3635,6 +3667,9 @@ void AArch64AsmPrinter::emitInstruction(const MachineInstr *MI) { return; } + if (emitDeactivationSymbolRelocation(MI->getDeactivationSymbol())) + return; + // Finally, do the automated lowerings for everything else. MCInst TmpInst; MCInstLowering.Lower(MI, TmpInst); diff --git a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp index 08466667c0fa5..b721c1f533726 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp @@ -1557,7 +1557,10 @@ void AArch64DAGToDAGISel::SelectPtrauthAuth(SDNode *N) { extractPtrauthBlendDiscriminators(AUTDisc, CurDAG); if (!Subtarget->isX16X17Safer()) { - SDValue Ops[] = {Val, AUTKey, AUTConstDisc, AUTAddrDisc}; + std::vector Ops = {Val, AUTKey, AUTConstDisc, AUTAddrDisc}; + // Copy deactivation symbol if present. + if (N->getNumOperands() > 4) + Ops.push_back(N->getOperand(4)); SDNode *AUT = CurDAG->getMachineNode(AArch64::AUTxMxN, DL, MVT::i64, MVT::i64, Ops); diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index d4099b56b6d6e..dd70d729ffc91 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -10203,6 +10203,9 @@ AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI, if (InGlue.getNode()) Ops.push_back(InGlue); + if (CLI.DeactivationSymbol) + Ops.push_back(DAG.getDeactivationSymbol(CLI.DeactivationSymbol)); + // If we're doing a tall call, use a TC_RETURN here rather than an // actual call instruction. if (IsTailCall) { diff --git a/llvm/lib/Target/AArch64/AArch64InstrFormats.td b/llvm/lib/Target/AArch64/AArch64InstrFormats.td index 6871c2d504cf6..61a8f764e39ed 100644 --- a/llvm/lib/Target/AArch64/AArch64InstrFormats.td +++ b/llvm/lib/Target/AArch64/AArch64InstrFormats.td @@ -2347,6 +2347,7 @@ class BImm pattern> let Inst{25-0} = addr; let DecoderMethod = "DecodeUnconditionalBranch"; + let supportsDeactivationSymbol = true; } class BranchImm pattern> @@ -2404,6 +2405,7 @@ class SignAuthOneData opcode_prefix, bits<2> opcode, string asm, let Inst{11-10} = opcode; let Inst{9-5} = Rn; let Inst{4-0} = Rd; + let supportsDeactivationSymbol = true; } class SignAuthZero opcode_prefix, bits<2> opcode, string asm, @@ -2417,6 +2419,7 @@ class SignAuthZero opcode_prefix, bits<2> opcode, string asm, let Inst{11-10} = opcode; let Inst{9-5} = 0b11111; let Inst{4-0} = Rd; + let supportsDeactivationSymbol = true; } class SignAuthTwoOperand opc, string asm, diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td index 03bad8ff8ac8a..b4d8649b31d6d 100644 --- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td +++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td @@ -2215,6 +2215,7 @@ let Predicates = [HasPAuth] in { let Size = 12; let Defs = [X16, X17]; let usesCustomInserter = 1; + let supportsDeactivationSymbol = true; } // A standalone pattern is used, so that literal 0 can be passed as $Disc. diff --git a/llvm/lib/Target/AArch64/GISel/AArch64CallLowering.cpp b/llvm/lib/Target/AArch64/GISel/AArch64CallLowering.cpp index 55694efafeed1..7907a3c283624 100644 --- a/llvm/lib/Target/AArch64/GISel/AArch64CallLowering.cpp +++ b/llvm/lib/Target/AArch64/GISel/AArch64CallLowering.cpp @@ -1421,6 +1421,7 @@ bool AArch64CallLowering::lowerCall(MachineIRBuilder &MIRBuilder, } else if (Info.CFIType) { MIB->setCFIType(MF, Info.CFIType->getZExtValue()); } + MIB->setDeactivationSymbol(MF, Info.DeactivationSymbol); MIB.add(Info.Callee); diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp index 8e4edefec42fd..d903787f00c7f 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp @@ -3077,6 +3077,11 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { } case Intrinsic::ptrauth_auth: case Intrinsic::ptrauth_resign: { + // We don't support this optimization on intrinsic calls with deactivation + // symbols, which are represented using operand bundles. + if (II->hasOperandBundles()) + break; + // (sign|resign) + (auth|resign) can be folded by omitting the middle // sign+auth component if the key and discriminator match. bool NeedSign = II->getIntrinsicID() == Intrinsic::ptrauth_resign; @@ -3088,6 +3093,11 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { // whatever we replace this sequence with. Value *AuthKey = nullptr, *AuthDisc = nullptr, *BasePtr; if (const auto *CI = dyn_cast(Ptr)) { + // We don't support this optimization on intrinsic calls with deactivation + // symbols, which are represented using operand bundles. + if (CI->hasOperandBundles()) + break; + BasePtr = CI->getArgOperand(0); if (CI->getIntrinsicID() == Intrinsic::ptrauth_sign) { if (CI->getArgOperand(1) != Key || CI->getArgOperand(2) != Disc) diff --git a/llvm/test/Bitcode/operand-bundles-bc-analyzer.ll b/llvm/test/Bitcode/operand-bundles-bc-analyzer.ll index 5628e17b4936e..01e5b3f6673ae 100644 --- a/llvm/test/Bitcode/operand-bundles-bc-analyzer.ll +++ b/llvm/test/Bitcode/operand-bundles-bc-analyzer.ll @@ -14,6 +14,7 @@ ; CHECK-NEXT: (N); + bool SupportsDeactivationSymbol = + EN->getInstruction().TheDef->getValueAsBit( + "supportsDeactivationSymbol"); + if (SupportsDeactivationSymbol) { + OS << "OPC_CaptureDeactivationSymbol,\n"; + OS.indent(FullIndexWidth + Indent); + } bool IsEmitNode = isa(EN); OS << (IsEmitNode ? "OPC_EmitNode" : "OPC_MorphNodeTo"); bool CompressVTs = EN->getNumVTs() < 3; @@ -1052,8 +1059,8 @@ unsigned MatcherTableEmitter::EmitMatcher(const Matcher *N, OS << '\n'; } - return 4 + !CompressVTs + !CompressNodeInfo + NumTypeBytes + - NumOperandBytes + NumCoveredBytes; + return 4 + SupportsDeactivationSymbol + !CompressVTs + !CompressNodeInfo + + NumTypeBytes + NumOperandBytes + NumCoveredBytes; } case Matcher::CompleteMatch: { const CompleteMatchMatcher *CM = cast(N);