Skip to content

Commit c72740f

Browse files
authored
Merge OpenAI Triton commit b3cf593 (#5487)
This PR changes the Triton base from c186592 to b3cf593 (Oct 30). Pass rate: 94.95%
2 parents 6f1525f + 4862fe6 commit c72740f

File tree

51 files changed

+2742
-3273
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

51 files changed

+2742
-3273
lines changed

include/triton/Analysis/Utility.h

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -268,19 +268,6 @@ bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy);
268268
// ConvertLayoutOpHelper in the future
269269
bool shouldUseDistSmem(Attribute srcLayout, Attribute dstLayout);
270270

271-
/// Multi-root DAG topological sort.
272-
/// Performs a topological sort of the Operation in the `toSort` SetVector.
273-
/// Returns a topologically sorted SetVector.
274-
/// It is faster than mlir::topologicalSort because it prunes nodes that have
275-
/// been visited before.
276-
SetVector<Operation *>
277-
multiRootTopologicalSort(const SetVector<Operation *> &toSort);
278-
279-
/// This uses the toplogicalSort above
280-
SetVector<Operation *>
281-
multiRootGetSlice(Operation *op, TransitiveFilter backwardFilter = nullptr,
282-
TransitiveFilter forwardFilter = nullptr);
283-
284271
/// Create a basic DataFlowSolver with constant and dead code analysis included.
285272
std::unique_ptr<DataFlowSolver> createDataFlowSolver();
286273

Lines changed: 224 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,224 @@
1+
#ifndef TRITONINSTRUMENT_FUNCTIONBUILDER_H
2+
#define TRITONINSTRUMENT_FUNCTIONBUILDER_H
3+
4+
#include "triton/Dialect/TritonInstrument/IR/Utility.h"
5+
6+
#include <string>
7+
#include <variant>
8+
9+
#include "llvm/ADT/ArrayRef.h"
10+
#include "llvm/ADT/StringRef.h"
11+
12+
namespace mlir {
13+
class ImplicitLocOpBuilder;
14+
class ModuleOp;
15+
class Operation;
16+
class RankedTensorType;
17+
class Type;
18+
class Value;
19+
} // namespace mlir
20+
21+
namespace mlir::triton {
22+
class FuncOp;
23+
24+
namespace instrument {
25+
26+
class ManglingArgs {
27+
public:
28+
using Arg = std::variant<Type, int, std::string>;
29+
30+
ManglingArgs() = default;
31+
ManglingArgs(const ManglingArgs &) = default;
32+
ManglingArgs(ManglingArgs &&) = default;
33+
ManglingArgs &operator=(const ManglingArgs &) = default;
34+
ManglingArgs &operator=(ManglingArgs &&) = default;
35+
36+
ManglingArgs(std::initializer_list<Arg> args) : args(args) {}
37+
38+
~ManglingArgs() = default;
39+
40+
template <typename T> void append(T arg) { args.push_back(arg); }
41+
42+
template <typename T> void append(ArrayRef<T> arg) {
43+
for (auto &a : arg) {
44+
args.push_back(a);
45+
}
46+
}
47+
48+
void append(ManglingArgs &other) {
49+
args.append(other.args.begin(), other.args.end());
50+
}
51+
52+
std::string mangleArg(Arg arg) const {
53+
if (auto type = std::get_if<Type>(&arg)) {
54+
auto hash = static_cast<uint64_t>(mlir::hash_value(*type));
55+
return std::string("_T") + llvm::utohexstr(hash);
56+
} else if (auto intVal = std::get_if<int>(&arg)) {
57+
return std::string("_I") + std::to_string(*intVal);
58+
} else if (auto stringVal = std::get_if<std::string>(&arg)) {
59+
return *stringVal;
60+
}
61+
llvm_unreachable("Unsupported argument type");
62+
}
63+
64+
std::string mangle(std::string baseName, int numWarps) const {
65+
std::string name = "__triton_consan_";
66+
name += baseName;
67+
name += "_nw" + std::to_string(numWarps);
68+
for (auto arg : args)
69+
name += mangleArg(arg);
70+
return name;
71+
}
72+
73+
private:
74+
SmallVector<Arg> args;
75+
};
76+
77+
/// Utility to mangle helper function names produced by the instrumentation
78+
/// passes. The mangled name encodes the base name, number of warps and the
79+
/// participating types.
80+
std::string mangleInstrumentHelperName(const std::string &baseName,
81+
int numWarps,
82+
llvm::ArrayRef<Type> types);
83+
84+
class FunctionBuilder {
85+
public:
86+
FunctionBuilder(ModuleOp module, AuxDataMap &auxData)
87+
: module(module), auxData(auxData) {}
88+
89+
// setWaiting: mark the base thread as waiting on the given barrier phase and
90+
// record that phase for deadlock detection.
91+
void createSetWaitingCall(ImplicitLocOpBuilder &b, Value mbar, int thread,
92+
Value phase, Value pred, Operation *insertPoint);
93+
// clearWaiting: clear the waiting flag and stored phase for the base thread.
94+
void createClearWaitingCall(ImplicitLocOpBuilder &b, Value mbar, int thread,
95+
Value pred, Operation *insertPoint);
96+
// checkAllActiveWaiting: assert that not all active threads are waiting on
97+
// matching barrier phases.
98+
void createCheckAllActiveWaitingCall(ImplicitLocOpBuilder &b, int activeMask,
99+
Value pred, Operation *insertPoint);
100+
// initBarrierState: Initialize the tracked barrier state to phase 0 and set
101+
// both the initial and current arrival counts.
102+
void createInitBarrierStateCall(ImplicitLocOpBuilder &b, Value mbar,
103+
int count, Operation *insertPoint);
104+
// verifyBarrierArrive: Check that applying the arrive count would not drive
105+
// the tracked current count negative. Triggers an assertion on failure.
106+
void createVerifyBarrierArriveCall(ImplicitLocOpBuilder &b, Value mbar,
107+
int count, Value pred,
108+
Operation *insertPoint);
109+
// updateBarrierState: Apply an arrive count to the tracked barrier state,
110+
// toggling the phase when the count reaches zero and reloading the current
111+
// count from the initial count.
112+
void createUpdateBarrierStateCall(ImplicitLocOpBuilder &b, Value mbar,
113+
int count, Value pred,
114+
Operation *insertPoint);
115+
// setWriteVisibility: Set the write visibility for a buffer. Marks the buffer
116+
// as visible to the threads set in threadMask. Clears out any other threads
117+
// from the visibility bitmask. We know this is safe because there cannot be
118+
// outstanding writes to this buffer at this point.
119+
void createSetWriteVisibilityCall(ImplicitLocOpBuilder &b, Value buf,
120+
uint64_t threadMask, Value pred,
121+
MemType memType, Operation *insertPoint);
122+
// setReadVisibility: add the threads set in threadMask to the buffer's read
123+
// visibility bitmask.
124+
void createSetReadVisibilityCall(ImplicitLocOpBuilder &b, Value buf,
125+
uint64_t threadMask, Value pred,
126+
MemType memType, Operation *insertPoint);
127+
// clearWriteTracking: clear all the information about threads writing to a
128+
// buffer.
129+
void createClearWriteTrackingCall(ImplicitLocOpBuilder &b, Value buf,
130+
Value pred, MemType memType,
131+
Operation *insertPoint);
132+
// clearReadVisibility: clear the read visibility for a buffer.
133+
void createClearReadVisibilityCall(ImplicitLocOpBuilder &b, Value buf,
134+
Value pred, MemType memType,
135+
Operation *insertPoint);
136+
// clearReadTracking: clear the read tracking for a buffer.
137+
void createClearReadTrackingCall(ImplicitLocOpBuilder &b, Value buf,
138+
Value pred, MemType memType,
139+
Operation *insertPoint);
140+
// trackVisibleWrites: snapshot buffers currently visible to the thread into
141+
// the tracking table for a barrier.
142+
void createTrackVisibleWritesCall(ImplicitLocOpBuilder &b, Value mbar,
143+
int thread, Value pred, MemType memType,
144+
Operation *insertPoint);
145+
// trackVisibleReads: snapshot buffers currently visible to the thread into
146+
// the read tracking table for a barrier.
147+
void createTrackVisibleReadsCall(ImplicitLocOpBuilder &b, Value mbar,
148+
int thread, Value pred, MemType memType,
149+
Operation *insertPoint);
150+
// transferVisibleWrites: transfer write visibility tracked by a barrier to
151+
// all threads in threadMask.
152+
void createTransferVisibleWritesCall(ImplicitLocOpBuilder &b, Value mbar,
153+
uint64_t threadMask, Value pred,
154+
MemType memType, Operation *insertPoint);
155+
// transferVisibleReads: transfer read visibility tracked by a barrier to all
156+
// threads in threadMask.
157+
void createTransferVisibleReadsCall(ImplicitLocOpBuilder &b, Value mbar,
158+
uint64_t threadMask, Value pred,
159+
MemType memType, Operation *insertPoint);
160+
// verifyWriteVisibility: ensure the thread either sees the latest write or no
161+
// other thread is writing the buffer.
162+
void createVerifyWriteVisibilityCall(ImplicitLocOpBuilder &b, Value buf,
163+
int thread, StringRef operandName,
164+
Value pred, MemType memType,
165+
Operation *insertPoint);
166+
// verifyReadVisibility: ensure all reads from the buffer are visible to the
167+
// thread.
168+
void createVerifyReadVisibilityCall(ImplicitLocOpBuilder &b, Value buf,
169+
int thread, StringRef operandName,
170+
Value pred, MemType memType,
171+
Operation *insertPoint);
172+
// copyWriteVisibility: replicate the write visibility bit of sourceThread to
173+
// every destination thread in destMask.
174+
void createCopyWriteVisibilityCall(ImplicitLocOpBuilder &b, int sourceThread,
175+
uint64_t destMask, Value pred,
176+
MemType memType, Operation *insertPoint);
177+
// copyReadVisibility: replicate the read visibility row of sourceThread to
178+
// every destination thread in destMask.
179+
void createCopyReadVisibilityCall(ImplicitLocOpBuilder &b, int sourceThread,
180+
uint64_t destMask, Value pred,
181+
MemType memType, Operation *insertPoint);
182+
// stageAccessForCommit: mark the buffer as staged (value -1) in the
183+
// outstanding commit table for this thread.
184+
void createStageAccessForCommitCall(ImplicitLocOpBuilder &b, Value buf,
185+
int thread, Value pred, ValueType buffers,
186+
ValueType outstandingCommits,
187+
Operation *insertPoint);
188+
// commitAccesses: convert staged entries to 1 and increment outstanding
189+
// commits greater than zero for the committing thread.
190+
void createCommitAccessesCall(ImplicitLocOpBuilder &b, int thread, Value pred,
191+
ValueType outstandingCommits,
192+
Operation *insertPoint);
193+
// clearOutstandingCommitsTransferWrites: clear entries farther than
194+
// outstandingNum from the thread and set write visibility for threads in
195+
// transferThreadMask.
196+
void createClearOutstandingCommitsTransferWritesCall(
197+
ImplicitLocOpBuilder &b, int thread, uint64_t transferThreadMask,
198+
int outstandingNum, Value pred, ValueType outstandingCommits,
199+
ValueType writeVisibility, Operation *insertPoint);
200+
// clearOutstandingCommitsTransferReads: clear entries farther than
201+
// outstandingNum from the thread and set read visibility for threads in
202+
// transferThreadMask.
203+
void createClearOutstandingCommitsTransferReadsCall(
204+
ImplicitLocOpBuilder &b, int thread, uint64_t transferThreadMask,
205+
int outstandingNum, Value pred, ValueType outstandingCommits,
206+
ValueType readVisibility, Operation *insertPoint);
207+
// checkOutstandingCommits: assert that the outstanding commit row for the
208+
// buffer is zero before the access described by pendingAccessType.
209+
void createCheckOutstandingCommitsCall(ImplicitLocOpBuilder &b, Value buf,
210+
int thread,
211+
StringRef pendingAccessType,
212+
Value pred, ValueType buffers,
213+
ValueType outstandingCommits,
214+
Operation *insertPoint);
215+
216+
private:
217+
ModuleOp module;
218+
AuxDataMap &auxData;
219+
};
220+
221+
} // namespace instrument
222+
} // namespace mlir::triton
223+
224+
#endif

0 commit comments

Comments
 (0)