
Commit 9e6c4b4

Merge branch 'main' into user/fanrongl/mtp3_support_for_ds32
2 parents df53eb5 + 979b3ae

10 files changed: +667 −31 lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -50,6 +50,7 @@ tensorrt_llm/pg_utils_bindings.*.so
 tensorrt_llm/flash_mla/
 tensorrt_llm/flash_mla_cpp_tllm.*.so
 tensorrt_llm/flash_mla_cpp_tllm.pyi
+tensorrt_llm/scripts
 *docs/cpp_docs*
 *docs/source/_cpp_gen*
 docs/source/**/*.rst

scripts/check_test_list.py

Lines changed: 1 addition & 1 deletion
@@ -261,7 +261,7 @@ def main():
         print(
             "Duplicate test names found in waives.txt, please delete one or combine them first!!!\n"
         )
-        pass_flag = False
+        # pass_flag = False

     non_existent_cases_file = os.path.join(llm_src, "nonexits_cases.json")
     if os.path.isfile(non_existent_cases_file) and os.path.getsize(

tensorrt_llm/_torch/pyexecutor/_util.py

Lines changed: 19 additions & 6 deletions
@@ -823,7 +823,8 @@ def create_py_executor_instance(
 def create_torch_sampler_args(mapping: Mapping, *, max_seq_len: int,
                               max_batch_size: int,
                               speculative_config: SpeculativeConfig,
-                              max_beam_width: int):
+                              max_beam_width: int,
+                              disable_flash_infer_sampling: bool):
     max_num_sequences = max_batch_size * mapping.pp_size
     max_draft_len = (0 if speculative_config is None else
                      speculative_config.max_draft_len)
@@ -836,20 +837,32 @@ def create_torch_sampler_args(mapping: Mapping, *, max_seq_len: int,
         max_total_draft_tokens=max_total_draft_tokens,
         max_num_sequences=max_num_sequences,
         max_beam_width=max_beam_width,
+        disable_flash_infer_sampling=disable_flash_infer_sampling,
     )


 def instantiate_sampler(
-        engine: PyTorchModelEngine, llm_args: TorchLlmArgs, mapping: Mapping,
-        max_batch_size: int, max_beam_width: int, max_seq_len: int,
-        mm_encoder_only: bool, speculative_config: SpeculativeConfig,
-        decoding_config: trtllm.DecodingConfig, kv_cache_config: KvCacheConfig):
+    engine: PyTorchModelEngine,
+    llm_args: TorchLlmArgs,
+    mapping: Mapping,
+    *,
+    max_batch_size: int,
+    max_beam_width: int,
+    max_seq_len: int,
+    mm_encoder_only: bool,
+    speculative_config: SpeculativeConfig,
+    decoding_config: trtllm.DecodingConfig,
+    kv_cache_config: KvCacheConfig,
+    disable_flash_infer_sampling: bool,
+):
     sampler_args = create_torch_sampler_args(
         mapping,
         max_seq_len=engine.max_seq_len,
         max_batch_size=max_batch_size,
         speculative_config=speculative_config,
-        max_beam_width=max_beam_width)
+        max_beam_width=max_beam_width,
+        disable_flash_infer_sampling=disable_flash_infer_sampling,
+    )
     decoding_mode = get_decoding_mode(decoding_config=decoding_config,
                                       max_beam_width=max_beam_width)
     if mapping.cp_config.get('cp_type') == CpType.STAR:
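Note: the refactored `instantiate_sampler` signature places a bare `*` after `mapping`, so every later parameter is keyword-only. A minimal self-contained sketch of that pattern follows; the toy `instantiate` function is illustrative only, not part of this commit.

# Sketch of the bare-`*` keyword-only pattern adopted by instantiate_sampler.
# With keyword-only parameters, adding a new flag (here
# disable_flash_infer_sampling) cannot silently shift positional arguments
# at existing call sites.
def instantiate(engine: str, *, max_batch_size: int,
                disable_flash_infer_sampling: bool = False) -> dict:
    return {
        "engine": engine,
        "max_batch_size": max_batch_size,
        "disable_flash_infer_sampling": disable_flash_infer_sampling,
    }

print(instantiate("demo-engine", max_batch_size=8))
# instantiate("demo-engine", 8)  # TypeError: too many positional arguments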

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py

Lines changed: 13 additions & 10 deletions
@@ -493,16 +493,19 @@ def drafting_loop_wrapper(model):
         )

     with allocation_scope(ExecutorMemoryType.SAMPLER, RestoreMode.PINNED):
-        sampler = instantiate_sampler(model_engine,
-                                      llm_args,
-                                      mapping,
-                                      max_batch_size=max_batch_size,
-                                      max_beam_width=max_beam_width,
-                                      max_seq_len=max_seq_len,
-                                      mm_encoder_only=mm_encoder_only,
-                                      speculative_config=spec_config,
-                                      decoding_config=decoding_config,
-                                      kv_cache_config=kv_cache_config)
+        sampler = instantiate_sampler(
+            model_engine,
+            llm_args,
+            mapping,
+            max_batch_size=max_batch_size,
+            max_beam_width=max_beam_width,
+            max_seq_len=max_seq_len,
+            mm_encoder_only=mm_encoder_only,
+            speculative_config=spec_config,
+            decoding_config=decoding_config,
+            kv_cache_config=kv_cache_config,
+            disable_flash_infer_sampling=llm_args._disable_flash_infer_sampling,
+        )
     logger.info(f"Using Sampler: {type(sampler).__name__}")

     if kv_connector_config is not None:

tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 16 additions & 5 deletions
@@ -20,7 +20,7 @@
 from dataclasses import dataclass
 from functools import cached_property
 from itertools import repeat
-from typing import Any, Callable, List, Optional, TypeVar, cast
+from typing import Any, Callable, List, Optional, Type, TypeVar, cast

 import numpy as np
 import torch
@@ -55,13 +55,15 @@
 from tensorrt_llm.mapping import Mapping
 from tensorrt_llm.sampling_params import SamplingParams

+from ..flashinfer_utils import IS_FLASHINFER_AVAILABLE
 from ..speculative.spec_tree_manager import SpecTreeManager
 from .finish_reason import FinishedState
 from .llm_request import LlmRequest, LlmRequestState, get_draft_token_length
 from .resource_manager import ResourceManager, ResourceManagerType
 from .sampling_utils import (
     GREEDY,
     GenericStrategyKeyType,
+    GroupedStrategySampler,
     SimpleGroupedStrategySampler,
     Strategy,
     UtilsSamplingParams,
@@ -268,7 +270,7 @@ def _request_strategy(request: LlmRequest, *, vocab_size: int) -> Strategy:
 def _group_requests_by_strategy_key(
     requests: Iterable[LlmRequest],
     *,
-    strategy_to_key: Callable[[Strategy], GenericStrategyKeyType],
+    strategy_to_key: Callable[[Strategy, bool], GenericStrategyKeyType],
     pin_memory: bool = False,
     vocab_size: int,
 ) -> dict[tuple[GenericStrategyKeyType, bool], tuple[torch.Tensor, List[Strategy]]]:
@@ -278,8 +280,8 @@ def _group_requests_by_strategy_key(
     )
     for req_index, req in enumerate(requests):
         strategy = _request_strategy(req, vocab_size=vocab_size)
-        strategy_key = strategy_to_key(strategy)
         speculation_needs_probs = req.py_draft_logits is not None and strategy is not GREEDY
+        strategy_key = strategy_to_key(strategy, speculation_needs_probs)
         group_dict_entry = group_dict[(strategy_key, speculation_needs_probs)]
         group_dict_entry[0].append(req_index)
         group_dict_entry[1].append(strategy)
@@ -608,6 +610,7 @@ class Args:
         max_num_sequences: int
         max_beam_width: int
         max_total_draft_tokens: int
+        disable_flash_infer_sampling: bool = False

     def __init__(self, args: Args):
         self.max_seq_len = args.max_seq_len
@@ -642,6 +645,14 @@ def __init__(self, args: Args):
             ]  # `in FinishReason` clashes with PyBind11: `TypeError: 'pybind11_type' object is not iterable`
         }

+        self._grouped_sampler_cls: Type[GroupedStrategySampler]
+        if IS_FLASHINFER_AVAILABLE and not args.disable_flash_infer_sampling:
+            from .sampling_utils_flashinfer import FlashInferGroupedStrategySampler
+
+            self._grouped_sampler_cls = FlashInferGroupedStrategySampler
+        else:
+            self._grouped_sampler_cls = SimpleGroupedStrategySampler
+
         # Initialize seed for multi-GPU consistency
         self._global_seed = 42
         self._generator = None
@@ -1251,7 +1262,7 @@ def _sample_batched_by_strategy(
             requests,
             pin_memory=True,
             vocab_size=logits_cuda.size(1),
-            strategy_to_key=SimpleGroupedStrategySampler.strategy_grouping_key,
+            strategy_to_key=self._grouped_sampler_cls.strategy_grouping_key,
         )
         generator_cuda = self.get_generator(cuda_device)

@@ -1308,7 +1319,7 @@ def _sample_batched_by_strategy(
             for _ in range(steps)
         ]
         group_next_tokens_cuda, group_softmax_cuda = (
-            SimpleGroupedStrategySampler.sample_grouped_strategies(
+            self._grouped_sampler_cls.sample_grouped_strategies(
                 strategy_key,
                 group_strategies_per_step,
                 group_logits_cuda,
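Note: the `__init__` change above selects the grouped-sampler implementation once, at construction time, so later call sites such as `_sample_batched_by_strategy` stay backend-agnostic. A standalone sketch of that selection logic follows; the placeholder classes and the try/except probe for flashinfer are assumptions for illustration, not code from this commit.

from typing import Type

# Assumed way an availability flag like IS_FLASHINFER_AVAILABLE could be
# derived: probe the optional dependency once at import time.
try:
    import flashinfer  # noqa: F401
    IS_FLASHINFER_AVAILABLE = True
except ImportError:
    IS_FLASHINFER_AVAILABLE = False

class SimpleGroupedStrategySampler:       # placeholder: pure-PyTorch fallback
    pass

class FlashInferGroupedStrategySampler:   # placeholder: FlashInfer-backed path
    pass

def pick_grouped_sampler_cls(disable_flash_infer_sampling: bool) -> Type:
    # Prefer the FlashInfer kernels unless the user opted out via the new
    # flag or the library is not importable in this environment.
    if IS_FLASHINFER_AVAILABLE and not disable_flash_infer_sampling:
        return FlashInferGroupedStrategySampler
    return SimpleGroupedStrategySampler

print(pick_grouped_sampler_cls(disable_flash_infer_sampling=False).__name__)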

tensorrt_llm/_torch/pyexecutor/sampling_utils.py

Lines changed: 12 additions & 9 deletions
@@ -33,13 +33,13 @@
 from typing_extensions import override


-TemperatureOnly = tuple[Literal["temperature"], float]
-TopK = tuple[Literal["top_k"], int, float]
-TopP = tuple[Literal["top_p"], float, float]
-TopKTopP = tuple[Literal["top_k_top_p"], int, float, float]
-Greedy = tuple[Literal["greedy"], None]
+TemperatureOnly: TypeAlias = tuple[Literal["temperature"], float]
+TopK: TypeAlias = tuple[Literal["top_k"], int, float]
+TopP: TypeAlias = tuple[Literal["top_p"], float, float]
+TopKTopP: TypeAlias = tuple[Literal["top_k_top_p"], int, float, float]
+Greedy: TypeAlias = tuple[Literal["greedy"], None]
 GREEDY: Greedy = ("greedy", None)
-Strategy = TopK | TopP | Greedy | TopKTopP | TemperatureOnly
+Strategy: TypeAlias = TopK | TopP | Greedy | TopKTopP | TemperatureOnly


 @dataclass(frozen=True, kw_only=True)
@@ -258,7 +258,10 @@ def sample(
     match strategy:
         case ("top_k", top_k, temperature):
             tokens, softmax = top_k_sampling_batch(
-                logits, top_k=top_k, temperature=temperature, generator=generator
+                logits,
+                top_k=top_k,
+                temperature=temperature,
+                generator=generator,
             )
         case ("top_p", top_p, temperature):
             tokens, softmax = top_p_sampling_batch(
@@ -292,7 +295,7 @@ def sample(
 class GroupedStrategySampler(Generic[GenericStrategyKeyType], abc.ABC):
     @staticmethod
     @abc.abstractmethod
-    def strategy_grouping_key(strategy: Strategy) -> GenericStrategyKeyType:
+    def strategy_grouping_key(strategy: Strategy, return_probs: bool) -> GenericStrategyKeyType:
         raise NotImplementedError

     @staticmethod
@@ -314,7 +317,7 @@ class SimpleGroupedStrategySampler(GroupedStrategySampler[Strategy]):

     @override
     @staticmethod
-    def strategy_grouping_key(strategy: Strategy) -> STRATEGY_KEY_TYPE:
+    def strategy_grouping_key(strategy: Strategy, return_probs: bool) -> STRATEGY_KEY_TYPE:
         return strategy

     @override
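Note: because the strategies are plain tagged tuples, grouping requests by strategy reduces to dictionary bucketing. A small self-contained sketch follows, with invented demo data; `strategy_grouping_key` mirrors the simple sampler's behavior of returning the strategy itself, while the new `return_probs` argument lets a backend split off groups that must also return probabilities.

from typing import Literal, TypeAlias

# Trimmed-down versions of the tagged-tuple aliases from sampling_utils.py.
TopK: TypeAlias = tuple[Literal["top_k"], int, float]
Greedy: TypeAlias = tuple[Literal["greedy"], None]
GREEDY: Greedy = ("greedy", None)
Strategy: TypeAlias = TopK | Greedy

def strategy_grouping_key(strategy: Strategy, return_probs: bool) -> Strategy:
    # The simple sampler ignores return_probs: identical strategy tuples
    # compare and hash equal, so they batch into one group.
    return strategy

requests = [("top_k", 40, 0.8), GREEDY, ("top_k", 40, 0.8)]  # invented demo
groups: dict[Strategy, list[int]] = {}
for index, strategy in enumerate(requests):
    groups.setdefault(strategy_grouping_key(strategy, False), []).append(index)
print(groups)  # {('top_k', 40, 0.8): [0, 2], ('greedy', None): [1]}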
