@@ -50,7 +50,8 @@ class CacheState final
5050
5151 CacheState (ModelConfig modelConfig, runtime::WorldConfig const & worldConfig,
5252 std::vector<SizeType32> const & attentionLayerNumPerPP, nvinfer1::DataType dataType,
53- AttentionType attentionType = AttentionType::kDEFAULT , int kvFactor = 2 , bool enableBlockReuse = false )
53+ AttentionType attentionType = AttentionType::kDEFAULT , int kvFactor = 2 , bool enableBlockReuse = false ,
54+ bool hasIndexerKCache = false , SizeType32 indexerDimPerHead = 0 , SizeType32 indexerKCacheQuantBlockSize = 128 )
5455 : mModelConfig (std::move(modelConfig))
5556 , mParallelConfig {worldConfig.getTensorParallelism (), worldConfig.getPipelineParallelism (),
5657 worldConfig.getContextParallelism (), worldConfig.enableAttentionDP (), worldConfig.getTensorParallelRank (),
@@ -59,34 +60,45 @@ class CacheState final
5960 , mAttentionConfig (attentionType, kvFactor)
6061 {
6162 mEnableBlockReuse = enableBlockReuse;
63+ mHasIndexerKCache = hasIndexerKCache;
64+ mIndexerDimPerHead = indexerDimPerHead;
65+ mIndexerKCacheQuantBlockSize = indexerKCacheQuantBlockSize;
6266 }
6367
6468 CacheState (std::vector<SizeType32> nbKvHeadPerLayer, SizeType32 sizePerHead, SizeType32 tokensPerBlock,
6569 SizeType32 tensorParallelism, SizeType32 pipelineParallelism, SizeType32 contextParallelism,
6670 std::vector<SizeType32> const & attentionLayerNumPerPP, nvinfer1::DataType dataType,
6771 AttentionType attentionType = AttentionType::kDEFAULT , int kvFactor = 2 , bool enableAttentionDP = false ,
68- int DPrank = 0 , int DPsize = 0 , bool enableBlockReuse = false )
72+ int DPrank = 0 , int DPsize = 0 , bool enableBlockReuse = false , bool hasIndexerKCache = false ,
73+ SizeType32 indexerDimPerHead = 0 , SizeType32 indexerKCacheQuantBlockSize = 128 )
6974 : mModelConfig {std::move (nbKvHeadPerLayer), sizePerHead, tokensPerBlock}
7075 , mParallelConfig {tensorParallelism, pipelineParallelism, contextParallelism, enableAttentionDP, DPrank, DPsize,
7176 attentionLayerNumPerPP}
7277 , mDataType {dataType}
7378 , mAttentionConfig (attentionType, kvFactor)
7479 {
7580 mEnableBlockReuse = enableBlockReuse;
81+ mHasIndexerKCache = hasIndexerKCache;
82+ mIndexerDimPerHead = indexerDimPerHead;
83+ mIndexerKCacheQuantBlockSize = indexerKCacheQuantBlockSize;
7684 }
7785
7886 CacheState (SizeType32 nbAttentionLayers, SizeType32 nbKvHeads, SizeType32 sizePerHead, SizeType32 tokensPerBlock,
7987 SizeType32 tensorParallelism, SizeType32 pipelineParallelism, SizeType32 contextParallelism,
8088 std::vector<SizeType32> const & attentionLayerNumPerPP, nvinfer1::DataType dataType,
8189 AttentionType attentionType = AttentionType::kDEFAULT , int kvFactor = 2 , bool enableAttentionDP = false ,
82- int DPrank = 0 , int DPsize = 0 , bool enableBlockReuse = false )
90+ int DPrank = 0 , int DPsize = 0 , bool enableBlockReuse = false , bool hasIndexerKCache = false ,
91+ SizeType32 indexerDimPerHead = 0 , SizeType32 indexerKCacheQuantBlockSize = 128 )
8392 : mModelConfig {std::vector (nbAttentionLayers, nbKvHeads), sizePerHead, tokensPerBlock}
8493 , mParallelConfig {tensorParallelism, pipelineParallelism, contextParallelism, enableAttentionDP, DPrank, DPsize,
8594 attentionLayerNumPerPP}
8695 , mDataType {dataType}
8796 , mAttentionConfig (attentionType, kvFactor)
8897 {
8998 mEnableBlockReuse = enableBlockReuse;
99+ mHasIndexerKCache = hasIndexerKCache;
100+ mIndexerDimPerHead = indexerDimPerHead;
101+ mIndexerKCacheQuantBlockSize = indexerKCacheQuantBlockSize;
90102 }
91103
92104 [[nodiscard]] bool operator ==(kv_cache::CacheState const & other) const noexcept
@@ -174,6 +186,21 @@ class CacheState final
174186 return mEnableBlockReuse ;
175187 }
176188
189+ [[nodiscard]] bool getHasIndexerKCache () const
190+ {
191+ return mHasIndexerKCache ;
192+ }
193+
194+ [[nodiscard]] SizeType32 getIndexerDimPerHead () const
195+ {
196+ return mIndexerDimPerHead ;
197+ }
198+
199+ [[nodiscard]] SizeType32 getIndexerKCacheQuantBlockSize () const
200+ {
201+ return mIndexerKCacheQuantBlockSize ;
202+ }
203+
177204 [[nodiscard]] std::string toString () const
178205 {
179206 std::stringstream sstring;
@@ -194,6 +221,9 @@ class CacheState final
194221 sstring << " dpRank:" << mParallelConfig .mDPrank << " \n " ;
195222 sstring << " dpSize:" << mParallelConfig .mDPsize << " \n " ;
196223 sstring << " enableBlockReuse:" << mEnableBlockReuse << " \n " ;
224+ sstring << " hasIndexerKCache:" << mHasIndexerKCache << " \n " ;
225+ sstring << " indexerDimPerHead:" << mIndexerDimPerHead << " \n " ;
226+ sstring << " indexerKCacheQuantBlockSize:" << mIndexerKCacheQuantBlockSize << " \n " ;
197227 return sstring.str ();
198228 }
199229
@@ -204,6 +234,9 @@ class CacheState final
204234 nvinfer1::DataType mDataType ;
205235 AttentionConfig mAttentionConfig ;
206236 bool mEnableBlockReuse {false };
237+ bool mHasIndexerKCache {false };
238+ SizeType32 mIndexerDimPerHead {0 };
239+ SizeType32 mIndexerKCacheQuantBlockSize {128 };
207240};
208241
209242struct MpiState
0 commit comments