|
| 1 | +#ifndef TRITONINSTRUMENT_FUNCTIONBUILDER_H |
| 2 | +#define TRITONINSTRUMENT_FUNCTIONBUILDER_H |
| 3 | + |
| 4 | +#include "triton/Dialect/TritonInstrument/IR/Utility.h" |
| 5 | + |
| 6 | +#include <string> |
| 7 | +#include <variant> |
| 8 | + |
| 9 | +#include "llvm/ADT/ArrayRef.h" |
| 10 | +#include "llvm/ADT/StringRef.h" |
| 11 | + |
// Forward declarations of MLIR classes, most of which are named by the
// declarations further down in this header.
// NOTE(review): ManglingArgs stores Type inside a std::variant and
// FunctionBuilder stores a ModuleOp by value; both uses need complete types,
// so the full definitions presumably arrive through the Utility.h include
// above - confirm.
namespace mlir {
class ImplicitLocOpBuilder;
class ModuleOp;
class Operation;
class RankedTensorType;
class Type;
class Value;
} // namespace mlir
| 20 | + |
| 21 | +namespace mlir::triton { |
| 22 | +class FuncOp; |
| 23 | + |
| 24 | +namespace instrument { |
| 25 | + |
| 26 | +class ManglingArgs { |
| 27 | +public: |
| 28 | + using Arg = std::variant<Type, int, std::string>; |
| 29 | + |
| 30 | + ManglingArgs() = default; |
| 31 | + ManglingArgs(const ManglingArgs &) = default; |
| 32 | + ManglingArgs(ManglingArgs &&) = default; |
| 33 | + ManglingArgs &operator=(const ManglingArgs &) = default; |
| 34 | + ManglingArgs &operator=(ManglingArgs &&) = default; |
| 35 | + |
| 36 | + ManglingArgs(std::initializer_list<Arg> args) : args(args) {} |
| 37 | + |
| 38 | + ~ManglingArgs() = default; |
| 39 | + |
| 40 | + template <typename T> void append(T arg) { args.push_back(arg); } |
| 41 | + |
| 42 | + template <typename T> void append(ArrayRef<T> arg) { |
| 43 | + for (auto &a : arg) { |
| 44 | + args.push_back(a); |
| 45 | + } |
| 46 | + } |
| 47 | + |
| 48 | + void append(ManglingArgs &other) { |
| 49 | + args.append(other.args.begin(), other.args.end()); |
| 50 | + } |
| 51 | + |
| 52 | + std::string mangleArg(Arg arg) const { |
| 53 | + if (auto type = std::get_if<Type>(&arg)) { |
| 54 | + auto hash = static_cast<uint64_t>(mlir::hash_value(*type)); |
| 55 | + return std::string("_T") + llvm::utohexstr(hash); |
| 56 | + } else if (auto intVal = std::get_if<int>(&arg)) { |
| 57 | + return std::string("_I") + std::to_string(*intVal); |
| 58 | + } else if (auto stringVal = std::get_if<std::string>(&arg)) { |
| 59 | + return *stringVal; |
| 60 | + } |
| 61 | + llvm_unreachable("Unsupported argument type"); |
| 62 | + } |
| 63 | + |
| 64 | + std::string mangle(std::string baseName, int numWarps) const { |
| 65 | + std::string name = "__triton_consan_"; |
| 66 | + name += baseName; |
| 67 | + name += "_nw" + std::to_string(numWarps); |
| 68 | + for (auto arg : args) |
| 69 | + name += mangleArg(arg); |
| 70 | + return name; |
| 71 | + } |
| 72 | + |
| 73 | +private: |
| 74 | + SmallVector<Arg> args; |
| 75 | +}; |
| 76 | + |
/// Utility to mangle helper function names produced by the instrumentation
/// passes. The mangled name encodes the base name, number of warps and the
/// participating types.
///
/// \param baseName  un-mangled name of the helper.
/// \param numWarps  warp count encoded into the name.
/// \param types     types encoded into the name.
/// \returns the mangled helper name.
///
/// NOTE(review): only the declaration is visible here; presumably the result
/// follows the same scheme as ManglingArgs::mangle - confirm against the
/// definition.
std::string mangleInstrumentHelperName(const std::string &baseName,
                                       int numWarps,
                                       llvm::ArrayRef<Type> types);
| 83 | + |
/// Entry points for emitting calls to the instrumentation helper functions.
///
/// Only the declarations are visible in this header; the member functions are
/// defined elsewhere. Each comment below describes the helper that the
/// corresponding create*Call method targets.
/// NOTE(review): `pred` presumably predicates the emitted call and
/// `insertPoint` presumably identifies the operation being instrumented -
/// confirm both against the implementation.
class FunctionBuilder {
public:
  /// Stores the module handle and a reference to \p auxData. The referenced
  /// AuxDataMap is not copied, so it must outlive this builder.
  FunctionBuilder(ModuleOp module, AuxDataMap &auxData)
      : module(module), auxData(auxData) {}

  /// setWaiting: mark the base thread as waiting on the given barrier phase
  /// and record that phase for deadlock detection.
  void createSetWaitingCall(ImplicitLocOpBuilder &b, Value mbar, int thread,
                            Value phase, Value pred, Operation *insertPoint);
  /// clearWaiting: clear the waiting flag and stored phase for the base
  /// thread.
  void createClearWaitingCall(ImplicitLocOpBuilder &b, Value mbar, int thread,
                              Value pred, Operation *insertPoint);
  /// checkAllActiveWaiting: assert that not all active threads are waiting on
  /// matching barrier phases.
  void createCheckAllActiveWaitingCall(ImplicitLocOpBuilder &b, int activeMask,
                                       Value pred, Operation *insertPoint);
  /// initBarrierState: initialize the tracked barrier state to phase 0 and
  /// set both the initial and current arrival counts.
  void createInitBarrierStateCall(ImplicitLocOpBuilder &b, Value mbar,
                                  int count, Operation *insertPoint);
  /// verifyBarrierArrive: check that applying the arrive count would not
  /// drive the tracked current count negative. Triggers an assertion on
  /// failure.
  void createVerifyBarrierArriveCall(ImplicitLocOpBuilder &b, Value mbar,
                                     int count, Value pred,
                                     Operation *insertPoint);
  /// updateBarrierState: apply an arrive count to the tracked barrier state,
  /// toggling the phase when the count reaches zero and reloading the current
  /// count from the initial count.
  void createUpdateBarrierStateCall(ImplicitLocOpBuilder &b, Value mbar,
                                    int count, Value pred,
                                    Operation *insertPoint);
  /// setWriteVisibility: set the write visibility for a buffer. Marks the
  /// buffer as visible to the threads set in threadMask and clears any other
  /// threads from the visibility bitmask. This is safe because there cannot
  /// be outstanding writes to the buffer at this point.
  void createSetWriteVisibilityCall(ImplicitLocOpBuilder &b, Value buf,
                                    uint64_t threadMask, Value pred,
                                    MemType memType, Operation *insertPoint);
  /// setReadVisibility: add the threads set in threadMask to the buffer's
  /// read visibility bitmask.
  void createSetReadVisibilityCall(ImplicitLocOpBuilder &b, Value buf,
                                   uint64_t threadMask, Value pred,
                                   MemType memType, Operation *insertPoint);
  /// clearWriteTracking: clear all the information about threads writing to
  /// a buffer.
  void createClearWriteTrackingCall(ImplicitLocOpBuilder &b, Value buf,
                                    Value pred, MemType memType,
                                    Operation *insertPoint);
  /// clearReadVisibility: clear the read visibility for a buffer.
  void createClearReadVisibilityCall(ImplicitLocOpBuilder &b, Value buf,
                                     Value pred, MemType memType,
                                     Operation *insertPoint);
  /// clearReadTracking: clear the read tracking for a buffer.
  void createClearReadTrackingCall(ImplicitLocOpBuilder &b, Value buf,
                                   Value pred, MemType memType,
                                   Operation *insertPoint);
  /// trackVisibleWrites: snapshot buffers currently visible to the thread
  /// into the tracking table for a barrier.
  void createTrackVisibleWritesCall(ImplicitLocOpBuilder &b, Value mbar,
                                    int thread, Value pred, MemType memType,
                                    Operation *insertPoint);
  /// trackVisibleReads: snapshot buffers currently visible to the thread into
  /// the read tracking table for a barrier.
  void createTrackVisibleReadsCall(ImplicitLocOpBuilder &b, Value mbar,
                                   int thread, Value pred, MemType memType,
                                   Operation *insertPoint);
  /// transferVisibleWrites: transfer write visibility tracked by a barrier to
  /// all threads in threadMask.
  void createTransferVisibleWritesCall(ImplicitLocOpBuilder &b, Value mbar,
                                       uint64_t threadMask, Value pred,
                                       MemType memType, Operation *insertPoint);
  /// transferVisibleReads: transfer read visibility tracked by a barrier to
  /// all threads in threadMask.
  void createTransferVisibleReadsCall(ImplicitLocOpBuilder &b, Value mbar,
                                      uint64_t threadMask, Value pred,
                                      MemType memType, Operation *insertPoint);
  /// verifyWriteVisibility: ensure the thread either sees the latest write or
  /// no other thread is writing the buffer.
  void createVerifyWriteVisibilityCall(ImplicitLocOpBuilder &b, Value buf,
                                       int thread, StringRef operandName,
                                       Value pred, MemType memType,
                                       Operation *insertPoint);
  /// verifyReadVisibility: ensure all reads from the buffer are visible to
  /// the thread.
  void createVerifyReadVisibilityCall(ImplicitLocOpBuilder &b, Value buf,
                                      int thread, StringRef operandName,
                                      Value pred, MemType memType,
                                      Operation *insertPoint);
  /// copyWriteVisibility: replicate the write visibility bit of sourceThread
  /// to every destination thread in destMask.
  void createCopyWriteVisibilityCall(ImplicitLocOpBuilder &b, int sourceThread,
                                     uint64_t destMask, Value pred,
                                     MemType memType, Operation *insertPoint);
  /// copyReadVisibility: replicate the read visibility row of sourceThread to
  /// every destination thread in destMask.
  void createCopyReadVisibilityCall(ImplicitLocOpBuilder &b, int sourceThread,
                                    uint64_t destMask, Value pred,
                                    MemType memType, Operation *insertPoint);
  /// stageAccessForCommit: mark the buffer as staged (value -1) in the
  /// outstanding commit table for this thread.
  void createStageAccessForCommitCall(ImplicitLocOpBuilder &b, Value buf,
                                      int thread, Value pred, ValueType buffers,
                                      ValueType outstandingCommits,
                                      Operation *insertPoint);
  /// commitAccesses: convert staged entries to 1 and increment outstanding
  /// commits greater than zero for the committing thread.
  void createCommitAccessesCall(ImplicitLocOpBuilder &b, int thread, Value pred,
                                ValueType outstandingCommits,
                                Operation *insertPoint);
  /// clearOutstandingCommitsTransferWrites: clear entries farther than
  /// outstandingNum from the thread and set write visibility for threads in
  /// transferThreadMask.
  void createClearOutstandingCommitsTransferWritesCall(
      ImplicitLocOpBuilder &b, int thread, uint64_t transferThreadMask,
      int outstandingNum, Value pred, ValueType outstandingCommits,
      ValueType writeVisibility, Operation *insertPoint);
  /// clearOutstandingCommitsTransferReads: clear entries farther than
  /// outstandingNum from the thread and set read visibility for threads in
  /// transferThreadMask.
  void createClearOutstandingCommitsTransferReadsCall(
      ImplicitLocOpBuilder &b, int thread, uint64_t transferThreadMask,
      int outstandingNum, Value pred, ValueType outstandingCommits,
      ValueType readVisibility, Operation *insertPoint);
  /// checkOutstandingCommits: assert that the outstanding commit row for the
  /// buffer is zero before the access described by pendingAccessType.
  void createCheckOutstandingCommitsCall(ImplicitLocOpBuilder &b, Value buf,
                                         int thread,
                                         StringRef pendingAccessType,
                                         Value pred, ValueType buffers,
                                         ValueType outstandingCommits,
                                         Operation *insertPoint);

private:
  // Module handle passed to the constructor.
  ModuleOp module;
  // Non-owning reference to the caller's auxiliary data map.
  AuxDataMap &auxData;
};
| 220 | + |
| 221 | +} // namespace instrument |
| 222 | +} // namespace mlir::triton |
| 223 | + |
| 224 | +#endif |
0 commit comments