Skip to content

Commit e8d4a56

Browse files
authored
[None][fix] fix eagle3 accuracy issue on sm120 (#8944)
Signed-off-by: bhsueh <11360707+byshiue@users.noreply.github.com>
1 parent a7033a9 commit e8d4a56

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/decoderXQAImplJIT.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -375,7 +375,7 @@ void DecoderXQAImplJIT::runImpl(XQAParams const& xqaParams, KVCacheBuffer const&
375375
.mask = reinterpret_cast<SpecDecParams::MaskType const*>(xqaParams.spec_decoding_packed_mask)};
376376
};
377377

378-
constexpr uint32_t kMAX_NB_KERNEL_PARAMS = 15;
378+
constexpr uint32_t kMAX_NB_KERNEL_PARAMS = 16;
379379
uint32_t idxNextParam = 0;
380380
void* kernelParams[kMAX_NB_KERNEL_PARAMS];
381381
auto appendParam = [&](auto* p) mutable

tensorrt_llm/_torch/attention_backend/trtllm.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1076,9 +1076,9 @@ def update_spec_dec_param(
10761076
spec_decoding_position_offsets = None
10771077
spec_decoding_packed_mask = None
10781078
spec_decoding_generation_lengths = None
1079-
# spec_dec mode should only be enabled for pre-Blackwell machines and when there's a spec-dec tree.
1080-
self.is_spec_decoding_enabled = is_spec_decoding_enabled and get_sm_version(
1081-
) < 100
1079+
# spec_dec mode should only be enabled for non-sm100 machines and when there's a spec-dec tree.
1080+
self.is_spec_decoding_enabled = is_spec_decoding_enabled and (
1081+
get_sm_version() < 100 or get_sm_version() == 120)
10821082

10831083
if get_sm_version() >= 100:
10841084
if is_spec_dec_tree or is_spec_dec_dynamic_tree:

0 commit comments

Comments
 (0)