Commit cdde15b

[TRTLLM-8540][feat] Add support for disagg in DSv3.2 (#8735)
Signed-off-by: Iman Tabrizian <10105175+tabrizian@users.noreply.github.com>
1 parent 264d38e commit cdde15b

24 files changed: +1131 -588 lines changed

cpp/include/tensorrt_llm/batch_manager/cacheTransceiver.h

Lines changed: 2 additions & 1 deletion
@@ -269,7 +269,8 @@ class CacheTransceiver : public BaseCacheTransceiver
     std::unique_ptr<executor::kv_cache::CacheState> mCacheState;
     std::unique_ptr<executor::kv_cache::ConnectionManager> mManager;
     std::optional<executor::CacheTransceiverConfig> mCacheTransceiverConfig;
-    std::unique_ptr<kv_cache_manager::CacheTransBufferManager> mCacheTransBufferManager;
+    std::vector<std::unique_ptr<kv_cache_manager::CacheTransBufferManager>> mCacheTransBufferManagers;
+    std::vector<kv_cache_manager::CacheTransBufferManager*> mCacheTransBufferManagerPtrs;
     // library handle to the communicator related features,
     // this is used to defer dependency resolution until needed.
     static std::mutex mDllMutex;
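
The transceiver now owns one transfer-buffer manager per cache (the regular KV pools plus, for DSv3.2, the indexer K cache) while downstream formatters keep receiving raw pointers. A minimal standalone sketch of that owning-vector-plus-view pattern, using a placeholder type rather than the real CacheTransBufferManager:

#include <memory>
#include <vector>

struct CacheTransBufferManager { /* placeholder for the real class */ };

struct CacheTransceiverSketch
{
    // Owning storage: one buffer manager per cache pool (e.g. KV pool + indexer K cache pool).
    std::vector<std::unique_ptr<CacheTransBufferManager>> mCacheTransBufferManagers;
    // Non-owning view handed to formatters that take raw pointers.
    std::vector<CacheTransBufferManager*> mCacheTransBufferManagerPtrs;

    void addManager()
    {
        mCacheTransBufferManagers.push_back(std::make_unique<CacheTransBufferManager>());
        mCacheTransBufferManagerPtrs.push_back(mCacheTransBufferManagers.back().get());
    }
};

Keeping the raw-pointer vector alongside the owning one avoids rebuilding it every time a formatter is created.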

cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h

Lines changed: 53 additions & 0 deletions
@@ -595,6 +595,21 @@ class WindowBlockManager
 
     ~WindowBlockManager();
 
+    [[nodiscard]] bool isEnableIndexerKCache() const
+    {
+        return mEnableIndexerKCache;
+    }
+
+    [[nodiscard]] SizeType32 getIndexerKCacheQuantBlockSize() const
+    {
+        return mIndexerKCacheQuantBlockSize;
+    }
+
+    [[nodiscard]] SizeType32 getIndexerKCacheIndexHeadDim() const
+    {
+        return mIndexerKCacheIndexHeadDim;
+    }
+
     void allocatePools(bool useUvm);
 
     void releasePools();
@@ -1021,6 +1036,21 @@ class BlockManager
         std::optional<kvc::BaseAgentConfig> agentConfig = std::nullopt, bool enableIndexerKCache = false,
         SizeType32 indexerKCacheQuantBlockSize = 128, SizeType32 indexerKCacheIndexHeadDim = 0);
 
+    [[nodiscard]] bool isEnableIndexerKCache() const
+    {
+        return mIsEnableIndexerKCache;
+    }
+
+    [[nodiscard]] SizeType32 getIndexerKCacheQuantBlockSize() const
+    {
+        return mIndexerKCacheQuantBlockSize;
+    }
+
+    [[nodiscard]] SizeType32 getIndexerKCacheIndexHeadDim() const
+    {
+        return mIndexerKCacheIndexHeadDim;
+    }
+
     BlockManager(BlockManager const&) = delete;
     BlockManager& operator=(BlockManager const&) = delete;
 
@@ -1398,6 +1428,10 @@ class BlockManager
     std::vector<SizeType32> mAbsolutePoolToRelativePoolIndex;
     // Record what sequences are currently managed by the block manager
     std::set<LlmRequest::RequestIdType> mManagedSequences;
+
+    bool mIsEnableIndexerKCache{false};
+    SizeType32 mIndexerKCacheQuantBlockSize{0};
+    SizeType32 mIndexerKCacheIndexHeadDim{0};
 };
 
 struct OffsetTableDimensions
@@ -1500,6 +1534,10 @@ class BaseKVCacheManager
 
     [[nodiscard]] virtual bool isEnableBlockReuse() const = 0;
 
+    [[nodiscard]] virtual bool isEnableIndexerKCache() const = 0;
+    [[nodiscard]] virtual SizeType32 getIndexerKCacheIndexHeadDim() const = 0;
+    [[nodiscard]] virtual SizeType32 getIndexerKCacheQuantBlockSize() const = 0;
+
     // void removeToken(SizeType32 seqSlotIdx);
     virtual void rewindKVCache(LlmRequest::RequestIdType requestId, SizeType32 rewindLengths) = 0;
 
@@ -1834,6 +1872,21 @@ class KVCacheManager : public BaseKVCacheManager
         return mEnableBlockReuse;
    }
 
+    [[nodiscard]] bool isEnableIndexerKCache() const override
+    {
+        return mBlockManager.isEnableIndexerKCache();
+    }
+
+    [[nodiscard]] SizeType32 getIndexerKCacheIndexHeadDim() const override
+    {
+        return mBlockManager.getIndexerKCacheIndexHeadDim();
+    }
+
+    [[nodiscard]] SizeType32 getIndexerKCacheQuantBlockSize() const override
+    {
+        return mBlockManager.getIndexerKCacheQuantBlockSize();
+    }
+
     void removeToken(LlmRequest::RequestIdType requestId);
     void rewindKVCache(LlmRequest::RequestIdType requestId, SizeType32 rewindLengths) override;
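
These accessors let transfer code ask the cache manager whether an indexer K cache exists and how it is laid out before sizing buffers. A hypothetical consumer sketch; the reduced interface below is an assumption, only the three accessor names come from the diff:

#include <cstdint>
#include <cstdio>

using SizeType32 = std::int32_t;

// Reduced stand-in exposing only the new BaseKVCacheManager accessors.
struct KVCacheManagerView
{
    bool enableIndexerKCache{true};
    SizeType32 indexerKCacheIndexHeadDim{128};
    SizeType32 indexerKCacheQuantBlockSize{128};

    bool isEnableIndexerKCache() const { return enableIndexerKCache; }
    SizeType32 getIndexerKCacheIndexHeadDim() const { return indexerKCacheIndexHeadDim; }
    SizeType32 getIndexerKCacheQuantBlockSize() const { return indexerKCacheQuantBlockSize; }
};

// Example consumer: decide whether an extra indexer K cache buffer is needed
// and report its per-token layout.
void planTransferBuffers(KVCacheManagerView const& kvCacheManager)
{
    if (kvCacheManager.isEnableIndexerKCache())
    {
        std::printf("indexer K cache: headDim=%d quantBlockSize=%d\n",
            kvCacheManager.getIndexerKCacheIndexHeadDim(),
            kvCacheManager.getIndexerKCacheQuantBlockSize());
    }
}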

cpp/include/tensorrt_llm/batch_manager/kvCacheUtils.h

Lines changed: 21 additions & 7 deletions
@@ -73,7 +73,8 @@ class BlockRange
         BaseKVCacheManager& cacheManager, BlockKey const& lastBlockKey, int32_t indexFromEnd)
     {
 
-        auto poolNum = cacheManager.getNumPools();
+        auto poolNum = cacheManager.getBlockManager().getNumPools(
+            /*includeBlockScalePools=*/false, /*includeIndexerKCachePools=*/false);
         TLLM_CHECK_WITH_INFO(poolNum == 1, "Reuse tree is not supported for multiple pools or variable window size");
 
         auto windowSize = cacheManager.getBlockManager().getWindowSizesMetadata().begin()->first;
@@ -136,13 +137,21 @@ class BlockRange
         return blockHashesPerWindow;
     }
 
-    BlockRangeForWindow getBlockRangeForWindow(SizeType32 windowSize) const
+    BlockRangeForWindow getBlockRangeForWindow(SizeType32 windowSize, bool useIndexerKCache = false) const
     {
         TLLM_CHECK_WITH_INFO(
             mPoolsPerWindow.find(windowSize) != mPoolsPerWindow.end(), "Window size %d not found", windowSize);
         auto pool = mPoolsPerWindow.at(windowSize).front();
         auto blockIds = mBlockIdsPerWindow.at(windowSize);
-        return BlockRangeForWindow(mManager, windowSize, std::move(blockIds), std::move(pool));
+        if (useIndexerKCache)
+        {
+            TLLM_CHECK(mIndexerKCachePool);
+            return BlockRangeForWindow(mManager, windowSize, std::move(blockIds), mIndexerKCachePool);
+        }
+        else
+        {
+            return BlockRangeForWindow(mManager, windowSize, std::move(blockIds), std::move(pool));
+        }
     }
 
     std::vector<SizeType32> getWindowSizes() const
@@ -167,9 +176,8 @@ class BlockRange
         , mRequestId(requestId)
         , mBlockIdsPerWindow(std::move(blockIdsPerWindow))
     {
-
-        // cacheManager.getBlockManager.getPrimaryPool(0);
-        auto poolNum = mManager->getNumPools();
+        auto poolNum = mManager->getBlockManager().getNumPools(
+            /*includeBlockScalePools=*/false, /*includeIndexerKCachePools=*/false);
         for (SizeType32 poolIdx = 0; poolIdx < poolNum; ++poolIdx)
         {
             auto windowSize = cacheManager.getBlockManager().getPoolWindowSize(poolIdx);
@@ -181,21 +189,27 @@ class BlockRange
         : mManager(&cacheManager)
         , mRequestId(requestId)
     {
-        auto poolNum = mManager->getNumPools();
+        auto poolNum = mManager->getBlockManager().getNumPools(
+            /*includeBlockScalePools=*/false, /*includeIndexerKCachePools=*/false);
         for (SizeType32 poolIdx = 0; poolIdx < poolNum; ++poolIdx)
        {
             auto windowSize = cacheManager.getBlockManager().getPoolWindowSize(poolIdx);
             mPoolsPerWindow[windowSize].push_back(cacheManager.getBlockManager().getPrimaryPool(poolIdx));
             mBlockIdsPerWindow[windowSize]
                 = cacheManager.getSequence(mRequestId).getCacheBlockIds(windowSize).at(kFIRST_AND_ONLY_BEAM);
        }
+        if (cacheManager.isEnableIndexerKCache())
+        {
+            mIndexerKCachePool = cacheManager.getIndexerKCachePool();
+        }
    }
 
private:
    BaseKVCacheManager const* mManager;
    LlmRequest::RequestIdType const mRequestId;
    std::unordered_map<SizeType32, std::vector<SizeType32>> mBlockIdsPerWindow;
    std::unordered_map<SizeType32, std::vector<runtime::ITensor::SharedPtr>> mPoolsPerWindow;
+    runtime::ITensor::SharedPtr mIndexerKCachePool;
 
    static constexpr SizeType32 kFIRST_AND_ONLY_BEAM = 0;
    static constexpr SizeType32 kFIRST_POOL_INDEX = 0;
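
With the indexer K cache enabled, the same block ids for a window address two pools: the regular KV pool and the indexer K cache pool. The new useIndexerKCache flag picks which pool a BlockRangeForWindow walks. A simplified, self-contained sketch of that selection (types are stand-ins, not the TensorRT-LLM classes):

#include <cassert>
#include <memory>
#include <vector>

// Simplified stand-ins for the runtime tensor and per-window range types.
using TensorPtr = std::shared_ptr<std::vector<float>>;

struct BlockRangeForWindowSketch
{
    std::vector<int> blockIds;
    TensorPtr pool;
};

struct BlockRangeSketch
{
    TensorPtr kvPool;            // regular KV cache pool for this window
    TensorPtr indexerKCachePool; // extra pool present only when the indexer K cache is enabled
    std::vector<int> blockIds;

    // Mirrors the new flag: the same block ids can address either the KV pool
    // or the indexer K cache pool, so both caches move with one block layout.
    BlockRangeForWindowSketch getBlockRangeForWindow(bool useIndexerKCache = false) const
    {
        if (useIndexerKCache)
        {
            assert(indexerKCachePool); // TLLM_CHECK(mIndexerKCachePool) in the real code
            return {blockIds, indexerKCachePool};
        }
        return {blockIds, kvPool};
    }
};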

cpp/include/tensorrt_llm/executor/dataTransceiverState.h

Lines changed: 36 additions & 3 deletions
@@ -50,7 +50,8 @@ class CacheState final
 
     CacheState(ModelConfig modelConfig, runtime::WorldConfig const& worldConfig,
         std::vector<SizeType32> const& attentionLayerNumPerPP, nvinfer1::DataType dataType,
-        AttentionType attentionType = AttentionType::kDEFAULT, int kvFactor = 2, bool enableBlockReuse = false)
+        AttentionType attentionType = AttentionType::kDEFAULT, int kvFactor = 2, bool enableBlockReuse = false,
+        bool hasIndexerKCache = false, SizeType32 indexerDimPerHead = 0, SizeType32 indexerKCacheQuantBlockSize = 128)
         : mModelConfig(std::move(modelConfig))
         , mParallelConfig{worldConfig.getTensorParallelism(), worldConfig.getPipelineParallelism(),
             worldConfig.getContextParallelism(), worldConfig.enableAttentionDP(), worldConfig.getTensorParallelRank(),
@@ -59,34 +60,45 @@ class CacheState final
         , mAttentionConfig(attentionType, kvFactor)
     {
         mEnableBlockReuse = enableBlockReuse;
+        mHasIndexerKCache = hasIndexerKCache;
+        mIndexerDimPerHead = indexerDimPerHead;
+        mIndexerKCacheQuantBlockSize = indexerKCacheQuantBlockSize;
     }
 
     CacheState(std::vector<SizeType32> nbKvHeadPerLayer, SizeType32 sizePerHead, SizeType32 tokensPerBlock,
         SizeType32 tensorParallelism, SizeType32 pipelineParallelism, SizeType32 contextParallelism,
         std::vector<SizeType32> const& attentionLayerNumPerPP, nvinfer1::DataType dataType,
         AttentionType attentionType = AttentionType::kDEFAULT, int kvFactor = 2, bool enableAttentionDP = false,
-        int DPrank = 0, int DPsize = 0, bool enableBlockReuse = false)
+        int DPrank = 0, int DPsize = 0, bool enableBlockReuse = false, bool hasIndexerKCache = false,
+        SizeType32 indexerDimPerHead = 0, SizeType32 indexerKCacheQuantBlockSize = 128)
         : mModelConfig{std::move(nbKvHeadPerLayer), sizePerHead, tokensPerBlock}
         , mParallelConfig{tensorParallelism, pipelineParallelism, contextParallelism, enableAttentionDP, DPrank, DPsize,
             attentionLayerNumPerPP}
         , mDataType{dataType}
         , mAttentionConfig(attentionType, kvFactor)
     {
         mEnableBlockReuse = enableBlockReuse;
+        mHasIndexerKCache = hasIndexerKCache;
+        mIndexerDimPerHead = indexerDimPerHead;
+        mIndexerKCacheQuantBlockSize = indexerKCacheQuantBlockSize;
     }
 
     CacheState(SizeType32 nbAttentionLayers, SizeType32 nbKvHeads, SizeType32 sizePerHead, SizeType32 tokensPerBlock,
         SizeType32 tensorParallelism, SizeType32 pipelineParallelism, SizeType32 contextParallelism,
         std::vector<SizeType32> const& attentionLayerNumPerPP, nvinfer1::DataType dataType,
         AttentionType attentionType = AttentionType::kDEFAULT, int kvFactor = 2, bool enableAttentionDP = false,
-        int DPrank = 0, int DPsize = 0, bool enableBlockReuse = false)
+        int DPrank = 0, int DPsize = 0, bool enableBlockReuse = false, bool hasIndexerKCache = false,
+        SizeType32 indexerDimPerHead = 0, SizeType32 indexerKCacheQuantBlockSize = 128)
         : mModelConfig{std::vector(nbAttentionLayers, nbKvHeads), sizePerHead, tokensPerBlock}
         , mParallelConfig{tensorParallelism, pipelineParallelism, contextParallelism, enableAttentionDP, DPrank, DPsize,
             attentionLayerNumPerPP}
         , mDataType{dataType}
         , mAttentionConfig(attentionType, kvFactor)
     {
         mEnableBlockReuse = enableBlockReuse;
+        mHasIndexerKCache = hasIndexerKCache;
+        mIndexerDimPerHead = indexerDimPerHead;
+        mIndexerKCacheQuantBlockSize = indexerKCacheQuantBlockSize;
     }
 
     [[nodiscard]] bool operator==(kv_cache::CacheState const& other) const noexcept
@@ -174,6 +186,21 @@ class CacheState final
         return mEnableBlockReuse;
     }
 
+    [[nodiscard]] bool getHasIndexerKCache() const
+    {
+        return mHasIndexerKCache;
+    }
+
+    [[nodiscard]] SizeType32 getIndexerDimPerHead() const
+    {
+        return mIndexerDimPerHead;
+    }
+
+    [[nodiscard]] SizeType32 getIndexerKCacheQuantBlockSize() const
+    {
+        return mIndexerKCacheQuantBlockSize;
+    }
+
     [[nodiscard]] std::string toString() const
     {
         std::stringstream sstring;
@@ -194,6 +221,9 @@ class CacheState final
         sstring << "dpRank:" << mParallelConfig.mDPrank << "\n";
         sstring << "dpSize:" << mParallelConfig.mDPsize << "\n";
         sstring << "enableBlockReuse:" << mEnableBlockReuse << "\n";
+        sstring << "hasIndexerKCache:" << mHasIndexerKCache << "\n";
+        sstring << "indexerDimPerHead:" << mIndexerDimPerHead << "\n";
+        sstring << "indexerKCacheQuantBlockSize:" << mIndexerKCacheQuantBlockSize << "\n";
         return sstring.str();
     }
 
@@ -204,6 +234,9 @@ class CacheState final
     nvinfer1::DataType mDataType;
     AttentionConfig mAttentionConfig;
     bool mEnableBlockReuse{false};
+    bool mHasIndexerKCache{false};
+    SizeType32 mIndexerDimPerHead{0};
+    SizeType32 mIndexerKCacheQuantBlockSize{128};
 };
 
 struct MpiState
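
Both sides of a disaggregated transfer must agree on whether the indexer K cache exists and how it is laid out, so the new fields travel with the serialized cache state. A minimal stand-in (not the real executor::kv_cache::CacheState) showing only the handshake data added here:

#include <cstdint>
#include <sstream>
#include <string>

using SizeType32 = std::int32_t;

// Simplified stand-in: only the fields that context and generation ranks
// need to agree on for the DSv3.2 indexer K cache transfer.
struct CacheStateSketch
{
    bool enableBlockReuse{false};
    bool hasIndexerKCache{false};          // sparse-attention indexer K cache present?
    SizeType32 indexerDimPerHead{0};       // head dim of the indexer keys
    SizeType32 indexerKCacheQuantBlockSize{128};

    std::string toString() const
    {
        std::stringstream ss;
        ss << "enableBlockReuse:" << enableBlockReuse << "\n";
        ss << "hasIndexerKCache:" << hasIndexerKCache << "\n";
        ss << "indexerDimPerHead:" << indexerDimPerHead << "\n";
        ss << "indexerKCacheQuantBlockSize:" << indexerKCacheQuantBlockSize << "\n";
        return ss.str();
    }
};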

cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp

Lines changed: 15 additions & 8 deletions
@@ -45,7 +45,9 @@ namespace tensorrt_llm::batch_manager::kv_cache_manager
 BlockRange getBlockRangeForSending(BaseKVCacheManager* cacheManager, LlmRequest const& llmRequest,
     BlockKey const& lastBlockKey, int32_t indexFromEnd, bool recvSideHasCP)
 {
-    auto poolNum = cacheManager->getBlockManager().getNumPools();
+    auto poolNum = cacheManager->getBlockManager().getNumPools(
+        /*includeBlockScalePools=*/false, /*includeIndexerKCachePools=*/false);
+
     // Note: When recv side has CP, the requested seqLen is lesser than seqLen on the sender side as seqLen is
     // distributed among CP ranks. So, we transfer all blocks from send side.
     if (poolNum > 1 || !cacheManager->isEnableBlockReuse() || lastBlockKey.uniqueTokens.size() == 0 || recvSideHasCP)
@@ -88,8 +90,9 @@ BlockRange getBlockRangeForSending(BaseKVCacheManager* cacheManager, LlmRequest
 BlockRange getBlockRangeForReceiving(
     BaseKVCacheManager* cacheManager, LlmRequest const& llmRequest, bool srcEnableBlockReuse, bool recvSideHasCP)
 {
-    auto poolNum = cacheManager->getBlockManager().getNumPools();
     // Note: When recv side has CP, we request all blocks from send side right now.
+    auto poolNum = cacheManager->getBlockManager().getNumPools(
+        /*includeBlockScalePools=*/false, /*includeIndexerKCachePools=*/false);
     if (poolNum == 1 && srcEnableBlockReuse && !recvSideHasCP)
     {
         // Build from all block ids, then slice off the reused blocks so we only transfer newly allocated ones.
@@ -171,7 +174,8 @@ void checkAlternateWindow(BaseKVCacheManager* cacheManager, BaseCacheFormatter::
     // if gen PP and context PP are different, cache formatter only support alternative window like gpt-oss.
     // which is one layer is WSA, and another layer is Full attention.
 
-    auto numPools = cacheManager->getBlockManager().getNumPools();
+    auto numPools = cacheManager->getBlockManager().getNumPools(
+        /*includeBlockScalePools=*/false, /*includeIndexerKCachePools=*/false);
     auto layerNum = cacheManager->getBlockManager().getNumLayers();
 
     auto selfPPNum = selfConfig.getParallelConfig().mPipelineParallelism;
@@ -248,7 +252,8 @@ void CacheFormatter::format(tensorrt_llm::batch_manager::TransferSession& sessio
     auto& blockManager = mCacheManager->getBlockManager();
     auto const& lastBlockKey = session.getLastBlockKey();
     auto blockRange = getBlockRangeForSending(mCacheManager, llmRequest, lastBlockKey, indexFromEnd);
-    auto const numPools = blockManager.getNumPools();
+    auto const numPools
+        = blockManager.getNumPools(/*includeBlockScalePools=*/false, /*includeIndexerKCachePools=*/false);
     // TODO(oargov): are we sure the other side has the same number of pools? this might not hold for pp_size>1...
 
     bool layerWise = common::getEnvDisaggLayerwise() && numPools == 1;
@@ -556,7 +561,8 @@ void CacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& sess
     TLLM_LOG_DEBUG("pickUpConnections size: %d connections size: %d", pickUpConnections.size(), connections.size());
     std::vector<runtime::ITensor::SharedPtr> recvBufferTmps;
     std::map<SizeType32, std::vector<runtime::ITensor::SharedPtr>> outputBuffersPerWindow;
-    auto const numPools = mCacheManager->getBlockManager().getNumPools();
+    auto const numPools = mCacheManager->getBlockManager().getNumPools(
+        /*includeBlockScalePools=*/false, /*includeIndexerKCachePools=*/false);
     // TODO(oargov): are we sure the other side has the same number of pools? this might not hold for pp_size>1...
     size_t blockNum = 0;
     size_t cacheBlockSizeSum = 0;
@@ -966,13 +972,14 @@ void CacheFormatter::unformat(tensorrt_llm::batch_manager::TransferSession& sess
 }
 
 std::unique_ptr<BaseCacheFormatter> createCacheFormatter(
-    BaseKVCacheManager* cacheManager, CacheTransBufferManager* cacheTransBufferManager, bool isMLA)
+    BaseKVCacheManager* cacheManager, std::vector<CacheTransBufferManager*> const& cacheTransBufferManagers, bool isMLA)
 {
+    TLLM_CHECK(!cacheTransBufferManagers.empty());
     if (isMLA)
     {
-        return std::make_unique<MLACacheFormatter>(cacheManager, cacheTransBufferManager);
+        return std::make_unique<MLACacheFormatter>(cacheManager, cacheTransBufferManagers);
     }
-    return std::make_unique<CacheFormatter>(cacheManager, cacheTransBufferManager);
+    return std::make_unique<CacheFormatter>(cacheManager, cacheTransBufferManagers[0]);
 }
 
 } // namespace tensorrt_llm::batch_manager::kv_cache_manager
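
The factory now accepts every transfer-buffer manager and routes them by formatter type: the MLA formatter (the path used for DeepSeek-style attention, including DSv3.2) sees all buffers, while the default formatter keeps using only the first. A reduced sketch of that dispatch, with placeholder classes and without the cacheManager argument and other details of the real implementation:

#include <cassert>
#include <memory>
#include <vector>

// Simplified stand-ins; the real classes live in tensorrt_llm::batch_manager::kv_cache_manager.
struct CacheTransBufferManager {};
struct BaseCacheFormatter { virtual ~BaseCacheFormatter() = default; };

struct CacheFormatterSketch : BaseCacheFormatter
{
    explicit CacheFormatterSketch(CacheTransBufferManager* mgr) : mManager(mgr) {}
    CacheTransBufferManager* mManager; // only the regular KV cache buffers
};

struct MLACacheFormatterSketch : BaseCacheFormatter
{
    explicit MLACacheFormatterSketch(std::vector<CacheTransBufferManager*> mgrs) : mManagers(std::move(mgrs)) {}
    std::vector<CacheTransBufferManager*> mManagers; // KV cache and indexer K cache buffers
};

// Mirrors the new dispatch: MLA gets all buffer managers, the default formatter gets the first.
std::unique_ptr<BaseCacheFormatter> makeFormatter(
    std::vector<CacheTransBufferManager*> const& managers, bool isMLA)
{
    assert(!managers.empty()); // TLLM_CHECK in the real code
    if (isMLA)
    {
        return std::make_unique<MLACacheFormatterSketch>(managers);
    }
    return std::make_unique<CacheFormatterSketch>(managers[0]);
}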

cpp/tensorrt_llm/batch_manager/cacheFormatter.h

Lines changed: 2 additions & 2 deletions
@@ -133,7 +133,7 @@ class CacheFormatter final : public BaseCacheFormatter
     CacheTransBufferManager* mCacheTransBufferManager;
 };
 
-std::unique_ptr<BaseCacheFormatter> createCacheFormatter(
-    BaseKVCacheManager* cacheManager, CacheTransBufferManager* cacheTransBufferManager, bool isMLA = false);
+std::unique_ptr<BaseCacheFormatter> createCacheFormatter(BaseKVCacheManager* cacheManager,
+    std::vector<CacheTransBufferManager*> const& cacheTransBufferManagers, bool isMLA = false);
 
 } // namespace tensorrt_llm::batch_manager::kv_cache_manager
