4242#include " SPIRVInternal.h"
4343#include " libSPIRV/SPIRVDebug.h"
4444
45+ #include " llvm/ADT/SmallPtrSet.h"
46+ #include " llvm/ADT/SmallVector.h"
4547#include " llvm/ADT/StringSwitch.h"
4648#include " llvm/Analysis/ValueTracking.h"
49+ #include " llvm/IR/Constants.h"
4750#include " llvm/IR/IRBuilder.h"
4851#include " llvm/IR/Instruction.h"
4952#include " llvm/IR/Instructions.h"
53+ #include " llvm/IR/Operator.h"
5054#include " llvm/IR/PatternMatch.h"
55+ #include " llvm/IR/TypedPointerType.h"
5156#include " llvm/Support/Debug.h"
5257
5358#include < algorithm>
59+ #include < optional>
5460#include < regex>
5561#include < set>
5662
@@ -62,6 +68,88 @@ using namespace SPIRV;
6268using namespace OCLUtil ;
6369
6470namespace SPIRV {
71+
72+ static unsigned getAddressSpaceFromType (const Type *Ty) {
73+ assert (Ty && " Can't deduce pointer AS" );
74+ if (auto *TypedPtr = dyn_cast<TypedPointerType>(Ty))
75+ return TypedPtr->getAddressSpace ();
76+ if (auto *Ptr = dyn_cast<PointerType>(Ty))
77+ return Ptr->getAddressSpace ();
78+ llvm_unreachable (" Can't deduce pointer AS" );
79+ }
80+
81+ // Performs an address space inference analysis.
82+ static unsigned getAddressSpaceFromValue (const Value *Ptr) {
83+ assert (Ptr && " Can't deduce pointer AS" );
84+
85+ SmallPtrSet<const Value *, 8 > Visited;
86+ SmallVector<const Value *, 8 > Worklist;
87+ Worklist.push_back (Ptr);
88+ unsigned AS = SPIRAS_Generic;
89+
90+ while (!Worklist.empty ()) {
91+ const Value *Current = Worklist.pop_back_val ();
92+ if (!Visited.insert (Current).second )
93+ continue ;
94+
95+ unsigned DeducedAS = getAddressSpaceFromType (Current->getType ());
96+ if (DeducedAS != SPIRAS_Generic)
97+ return DeducedAS;
98+ AS = DeducedAS;
99+
100+ // Find origins of the pointer and add to the worklist.
101+ if (auto *Op = dyn_cast<Operator>(Current)) {
102+ switch (Op->getOpcode ()) {
103+ case Instruction::AddrSpaceCast:
104+ case Instruction::BitCast:
105+ case Instruction::GetElementPtr:
106+ Worklist.push_back (Op->getOperand (0 ));
107+ break ;
108+ case Instruction::Select:
109+ Worklist.push_back (Op->getOperand (1 ));
110+ Worklist.push_back (Op->getOperand (2 ));
111+ break ;
112+ case Instruction::PHI: {
113+ auto *Phi = cast<PHINode>(Op);
114+ for (Value *Incoming : Phi->incoming_values ())
115+ Worklist.push_back (Incoming);
116+ break ;
117+ }
118+ default :
119+ break ;
120+ }
121+ }
122+ }
123+
124+ return AS;
125+ }
126+
127+ // Sets memory semantic mask of an atomic depending on a pointer argument
128+ // address space.
129+ static unsigned
130+ getAtomicPointerMemorySemanticsMemoryMask (const Value *Ptr,
131+ const Type *RecordedType) {
132+ assert ((Ptr && RecordedType) &&
133+ " Can't evaluate atomic builtin's memory semantic" );
134+ unsigned AddrSpace = getAddressSpaceFromType (RecordedType);
135+ if (AddrSpace == SPIRAS_Generic)
136+ AddrSpace = getAddressSpaceFromValue (Ptr);
137+
138+ switch (AddrSpace) {
139+ case SPIRAS_Global:
140+ case SPIRAS_GlobalDevice:
141+ case SPIRAS_GlobalHost:
142+ return MemorySemanticsCrossWorkgroupMemoryMask;
143+ case SPIRAS_Local:
144+ return MemorySemanticsWorkgroupMemoryMask;
145+ case SPIRAS_Generic:
146+ return MemorySemanticsCrossWorkgroupMemoryMask |
147+ MemorySemanticsWorkgroupMemoryMask;
148+ default :
149+ return MemorySemanticsMaskNone;
150+ }
151+ }
152+
65153static size_t getOCLCpp11AtomicMaxNumOps (StringRef Name) {
66154 return StringSwitch<size_t >(Name)
67155 .Cases ({" load" , " flag_test_and_set" , " flag_clear" }, 3 )
@@ -700,6 +788,11 @@ void OCLToSPIRVBase::transAtomicBuiltin(CallInst *CI,
700788 const size_t ScopeIdx = ArgsCount - 1 ;
701789 const size_t OrderIdx = ScopeIdx - NumOrder;
702790
791+ unsigned PtrMemSemantics = MemorySemanticsMaskNone;
792+ if (Mutator.arg_size () > 0 )
793+ PtrMemSemantics = getAtomicPointerMemorySemanticsMemoryMask (
794+ Mutator.getArg (0 ), Mutator.getType (0 ));
795+
703796 if (NeedsNegate) {
704797 Mutator.mapArg (1 , [=](Value *V) {
705798 IRBuilder<> IRB (CI);
@@ -710,9 +803,20 @@ void OCLToSPIRVBase::transAtomicBuiltin(CallInst *CI,
710803 return transOCLMemScopeIntoSPIRVScope (V, OCLMS_device, CI);
711804 });
712805 for (size_t I = 0 ; I < NumOrder; ++I) {
713- Mutator.mapArg (OrderIdx + I, [=](Value *V) {
714- return transOCLMemOrderIntoSPIRVMemorySemantics (V, OCLMO_seq_cst, CI);
715- });
806+ Mutator.mapArg (
807+ OrderIdx + I, [=](IRBuilder<> &Builder, Value *V) -> Value * {
808+ Value *MemSem =
809+ transOCLMemOrderIntoSPIRVMemorySemantics (V, OCLMO_seq_cst, CI);
810+ if (PtrMemSemantics == MemorySemanticsMaskNone)
811+ return MemSem;
812+
813+ auto *MemSemTy = cast<IntegerType>(MemSem->getType ());
814+ auto *Mask = ConstantInt::get (MemSemTy, PtrMemSemantics);
815+ if (auto *Const = dyn_cast<ConstantInt>(MemSem))
816+ return static_cast <Value *>(ConstantInt::get (
817+ MemSemTy, Const->getZExtValue () | PtrMemSemantics));
818+ return Builder.CreateOr (MemSem, Mask);
819+ });
716820 }
717821
718822 // Order of args in SPIR-V:
0 commit comments