2020from dataclasses import dataclass
2121from functools import cached_property
2222from itertools import repeat
23- from typing import Any , Callable , List , Optional , TypeVar , cast
23+ from typing import Any , Callable , List , Optional , Type , TypeVar , cast
2424
2525import numpy as np
2626import torch
5555from tensorrt_llm .mapping import Mapping
5656from tensorrt_llm .sampling_params import SamplingParams
5757
58+ from ..flashinfer_utils import IS_FLASHINFER_AVAILABLE
5859from ..speculative .spec_tree_manager import SpecTreeManager
5960from .finish_reason import FinishedState
6061from .llm_request import LlmRequest , LlmRequestState , get_draft_token_length
6162from .resource_manager import ResourceManager , ResourceManagerType
6263from .sampling_utils import (
6364 GREEDY ,
6465 GenericStrategyKeyType ,
66+ GroupedStrategySampler ,
6567 SimpleGroupedStrategySampler ,
6668 Strategy ,
6769 UtilsSamplingParams ,
@@ -268,7 +270,7 @@ def _request_strategy(request: LlmRequest, *, vocab_size: int) -> Strategy:
268270def _group_requests_by_strategy_key (
269271 requests : Iterable [LlmRequest ],
270272 * ,
271- strategy_to_key : Callable [[Strategy ], GenericStrategyKeyType ],
273+ strategy_to_key : Callable [[Strategy , bool ], GenericStrategyKeyType ],
272274 pin_memory : bool = False ,
273275 vocab_size : int ,
274276) -> dict [tuple [GenericStrategyKeyType , bool ], tuple [torch .Tensor , List [Strategy ]]]:
@@ -278,8 +280,8 @@ def _group_requests_by_strategy_key(
278280 )
279281 for req_index , req in enumerate (requests ):
280282 strategy = _request_strategy (req , vocab_size = vocab_size )
281- strategy_key = strategy_to_key (strategy )
282283 speculation_needs_probs = req .py_draft_logits is not None and strategy is not GREEDY
284+ strategy_key = strategy_to_key (strategy , speculation_needs_probs )
283285 group_dict_entry = group_dict [(strategy_key , speculation_needs_probs )]
284286 group_dict_entry [0 ].append (req_index )
285287 group_dict_entry [1 ].append (strategy )
@@ -608,6 +610,7 @@ class Args:
608610 max_num_sequences : int
609611 max_beam_width : int
610612 max_total_draft_tokens : int
613+ disable_flash_infer_sampling : bool = False
611614
612615 def __init__ (self , args : Args ):
613616 self .max_seq_len = args .max_seq_len
@@ -642,6 +645,14 @@ def __init__(self, args: Args):
642645 ] # `in FinishReason` clashes with PyBind11: `TypeError: 'pybind11_type' object is not iterable`
643646 }
644647
648+ self ._grouped_sampler_cls : Type [GroupedStrategySampler ]
649+ if IS_FLASHINFER_AVAILABLE and not args .disable_flash_infer_sampling :
650+ from .sampling_utils_flashinfer import FlashInferGroupedStrategySampler
651+
652+ self ._grouped_sampler_cls = FlashInferGroupedStrategySampler
653+ else :
654+ self ._grouped_sampler_cls = SimpleGroupedStrategySampler
655+
645656 # Initialize seed for multi-GPU consistency
646657 self ._global_seed = 42
647658 self ._generator = None
@@ -1251,7 +1262,7 @@ def _sample_batched_by_strategy(
12511262 requests ,
12521263 pin_memory = True ,
12531264 vocab_size = logits_cuda .size (1 ),
1254- strategy_to_key = SimpleGroupedStrategySampler .strategy_grouping_key ,
1265+ strategy_to_key = self . _grouped_sampler_cls .strategy_grouping_key ,
12551266 )
12561267 generator_cuda = self .get_generator (cuda_device )
12571268
@@ -1308,7 +1319,7 @@ def _sample_batched_by_strategy(
13081319 for _ in range (steps )
13091320 ]
13101321 group_next_tokens_cuda , group_softmax_cuda = (
1311- SimpleGroupedStrategySampler .sample_grouped_strategies (
1322+ self . _grouped_sampler_cls .sample_grouped_strategies (
13121323 strategy_key ,
13131324 group_strategies_per_step ,
13141325 group_logits_cuda ,
0 commit comments