From aa376d65bcb2eb9bde502d6b37e2e3d50128e330 Mon Sep 17 00:00:00 2001
From: ixlmar <206748156+ixlmar@users.noreply.github.com>
Date: Tue, 11 Nov 2025 17:39:12 +0000
Subject: [PATCH] test: ensure sampling is async

Signed-off-by: ixlmar <206748156+ixlmar@users.noreply.github.com>
---
 .../_torch/sampler/test_torch_sampler.py | 1075 ++++++++++-------
 tests/unittest/utils/util.py             |   35 +-
 2 files changed, 677 insertions(+), 433 deletions(-)

diff --git a/tests/unittest/_torch/sampler/test_torch_sampler.py b/tests/unittest/_torch/sampler/test_torch_sampler.py
index 8072e638ea7..75a17044bca 100644
--- a/tests/unittest/_torch/sampler/test_torch_sampler.py
+++ b/tests/unittest/_torch/sampler/test_torch_sampler.py
@@ -12,11 +12,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from contextlib import contextmanager
+from contextlib import contextmanager, nullcontext
 from dataclasses import dataclass
 from itertools import product
-from random import shuffle as shuffle_inplace
-from typing import Callable, Final, Generator, Optional, Type, Union, cast
+from typing import (
+    Callable,
+    ContextManager,
+    Final,
+    Generator,
+    Optional,
+    Protocol,
+    Type,
+    TypeVar,
+    Union,
+    cast,
+)
 
 import flashinfer.sampling
 import numpy as np
@@ -344,6 +354,52 @@ def test_should_provide_draft_probs_consistency(
     assert torch_sampler.should_provide_draft_probs(request) == (not is_greedy)
 
 
+class UutProvider(Protocol):
+    def __call__(self, is_warmup: bool) -> ContextManager[Callable[[], None]]: ...
+
+
+def _run_test_with_warmup(
+    uut_provider: UutProvider,
+    warmup_sizes_bytes: tuple[int, ...] = (4 * 2**30,),
+    max_sync_s: Optional[float] = None,
+):
+    """Run UUT including setup and warmup.
+
+    This is mainly used to check that the UUT does not trigger a CUDA device sync. Thus,
+    given that PyTorch's caching memory allocator can synchronize with the device when it
+    runs out of cached GPU memory segments, the warmup allocates some GPU memory.
+
+    The warmup also runs the test once. This avoids issues such as lazy loading
+    of device code. The UUT provider can use the 'is_warmup' argument to adapt its
+    behavior to the warmup and final test runs.
+
+    If max_sync_s is provided, this helper checks that the UUT does not synchronize with
+    the device, assuming that the synchronous (CPU) part of the code takes no longer than
+    max_sync_s seconds to complete.
+
+    It is the user's responsibility to ensure that the amount of submitted work
+    does not exceed the CUDA driver/device queue capacity, which would make
+    the execution appear synchronous.
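+
+    Illustrative usage (hypothetical UUT, for documentation purposes only):
+
+        @contextmanager
+        def _uut_provider(is_warmup: bool) -> Generator[Callable[[], None], None, None]:
+            buf = torch.zeros(8, device="cuda")
+            # Enqueue asynchronous GPU work only; no .item()/.cpu() in the UUT.
+            yield lambda: buf.add_(1)
+
+        _run_test_with_warmup(_uut_provider, max_sync_s=0.3)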
+ """ + with torch.cuda.Stream(): + with uut_provider(is_warmup=True) as uut: + bufs = [] + for warmup_size in warmup_sizes_bytes: + bufs.append( + torch.ones(warmup_size, device=torch.cuda.current_device(), dtype=torch.int8) + ) + del bufs + uut() + + with uut_provider(is_warmup=False) as uut: + with ( + assert_no_cuda_sync(sync_timeout_s=max_sync_s) + if max_sync_s is not None + else nullcontext() + ): + uut() + + @force_ampere @pytest.mark.parametrize( "draft_len, with_ctx, with_gen", @@ -363,7 +419,7 @@ def test_select_generated_logits(draft_len: int, with_ctx: bool, with_gen: bool) device = torch.device("cuda") @contextmanager - def _test_runner() -> Generator[Callable[[], None], None, None]: + def _test_runner(is_warmup: bool) -> Generator[Callable[[], None], None, None]: class ContextRequestMock: def __init__(self, return_context_logits: bool): self._return_context_logits = return_context_logits @@ -499,18 +555,7 @@ def _uut(res=res): selected_logits = res.result.selected_logits torch.testing.assert_close(selected_logits.to("cpu"), all_logits[expected_logit_indices]) - with _test_runner() as uut: - # Pre-allocates a large chunk of memory, because PyTorch caching memory allocator - # can sync otherwise. - buf = torch.ones((2**30,), device=device) - del buf - # Warmup to avoid syncs due to lazy loading of kernels - uut() - - with torch.cuda.Stream(): - with _test_runner() as uut: - with assert_no_cuda_sync(): - uut() + _run_test_with_warmup(_test_runner, max_sync_s=0.3) MAX_NUM_SEQUENCES = 128 @@ -817,9 +862,9 @@ class VaryParams: assert constrained_indices is not None no_shuffle_start_idx, no_shuffle_end_idx = constrained_indices head_shuffled = mixed_params_list[:no_shuffle_start_idx] - shuffle_inplace(head_shuffled) + rng.shuffle(head_shuffled) # inplace tail_shuffled = mixed_params_list[no_shuffle_end_idx:] - shuffle_inplace(tail_shuffled) + rng.shuffle(tail_shuffled) # inplace mixed_params_list = ( head_shuffled + mixed_params_list[no_shuffle_start_idx:no_shuffle_end_idx] @@ -827,21 +872,21 @@ class VaryParams: ) label += "_oneContiguous" elif isinstance(constraint_value, Shuffle): - shuffle_inplace(mixed_params_list) + rng.shuffle(mixed_params_list) # inplace label += "_shuffled" elif isinstance(constraint_value, VaryParams): - shuffle_inplace(mixed_params_list) + rng.shuffle(mixed_params_list) # inplace def _perturb_params(param: SamplingParams): top_k = param.top_k if top_k is not None: - top_k = rng.integers(vocab_size // 3) + top_k = rng.integers(2, vocab_size // 3) top_p = param.top_p if top_p is not None: - top_p *= rng.random() + top_p *= max(rng.random(), 1e-6) temperature = param.temperature if temperature is not None: - temperature *= rng.random() + temperature *= max(rng.random(), 1e-6) return SamplingParams( top_p=top_p, top_k=top_k, @@ -1065,22 +1110,41 @@ def _sample( model_outputs: dict[str, torch.Tensor], *, num_repeats: Optional[int] = None, + allow_sync: bool = True, ) -> torch.Tensor: """Call sample_async. Optionally, run sampling repeatedly, e.g., to gather statistics. 
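+
+        If allow_sync is False, the first sample_async invocation is additionally wrapped
+        in assert_no_cuda_sync to verify that it does not synchronize with the device.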
""" assert not scheduled_requests.context_requests - # FIXME: Currently, sample_async is not fully async (TRTLLM-9175) - # with assert_no_cuda_sync(sync_timeout_s=(0.01 * (num_repeats or 1) + 0.25)): + num_actual_repeats = num_repeats if num_repeats is not None else 1 + + T = TypeVar("T") + is_first = True + + def maybe_check_no_sync(func: Callable[[], T]) -> T: + # The device-side sleep submitted by assert_no_cuda_sync blocks CUDA operations + # once the amount of enqueued work becomes large enough. + # Only checking the first sampling repetition to avoid this. + nonlocal is_first + with ( + assert_no_cuda_sync(sync_timeout_s=0.25) + if (not allow_sync and is_first) + else nullcontext() + ): + is_first = False + return func() + sample_states = [ - sampler.sample_async( - scheduled_requests, - model_outputs=model_outputs, - num_context_logits_prefix_sum=[0], - resource_manager=None, # only used for tree sampling, which is not tested here + maybe_check_no_sync( + lambda: sampler.sample_async( + scheduled_requests, + model_outputs=model_outputs, + num_context_logits_prefix_sum=[0], + resource_manager=None, # only used for tree sampling, which is not tested here + ) ) - for _ in range(num_repeats if num_repeats is not None else 1) + for _ in range(num_actual_repeats) ] new_tokens_tensors = [] for sample_state in sample_states: @@ -1159,6 +1223,9 @@ def test_probs( vocab_size: int, params_label: str, allow_zero_draft_len: bool, # used by fixtures + sampling_params_list: list[SamplingParams], + seq_slot_assignment: tuple[list[int], int], + with_draft_logits: bool, ): """Validate probabilities returned by sample_async. @@ -1168,7 +1235,9 @@ def test_probs( This test checks that the presence of py_target_probs behaves as expected and validates the values of this attribute (when present). """ - with torch.cuda.Stream(): + + @contextmanager + def _uut_provider(is_warmup: bool) -> Generator[Callable[[], None], None, None]: torch.manual_seed(42) # torch.testing.make_tensor does not accept Generator strategy_tags = { @@ -1181,14 +1250,31 @@ def test_probs( ] } - _ = self._sample( - sampler, - scheduled_requests=mock_requests, - model_outputs=model_outputs, - ) + if is_warmup: + # Use separate requests for warmup, because prob outputs are attached to + # requests. + uut_mock_requests = self._build_mock_requests( + sampling_params_list=sampling_params_list, + vocab_size=vocab_size, + seq_slot_assignment=seq_slot_assignment, + with_draft_logits=with_draft_logits, + draft_lens=draft_lens, + ) + else: + uut_mock_requests = mock_requests + + def _uut(): + _ = self._sample( + sampler, + scheduled_requests=uut_mock_requests, + model_outputs=model_outputs, + allow_sync=is_warmup, + ) + + yield _uut logit_offset = 0 - for req, draft_len in zip(mock_requests.all_requests(), draft_lens): + for req, draft_len in zip(uut_mock_requests.all_requests(), draft_lens): assert req.py_target_probs is not None probs = req.py_target_probs.cpu() assert probs.shape == (draft_len + 1, vocab_size) @@ -1310,6 +1396,452 @@ def test_probs( logit_offset += steps + _run_test_with_warmup(_uut_provider) + + def _compute_probs( + self, + *, + use_flashinfer: bool, + model_outputs: dict[str, torch.Tensor], + sampling_params_list: list[SamplingParams], + seq_slot_assignment: tuple[list[int], int], + vocab_size: int, + max_draft_len: int, + draft_lens: list[int], + ) -> ScheduledRequests: + """Construct a batch of requests with given sampling params and invoke sampler to compute probs. 
+ + The probs (PMFs) corresponding to the provided model_outputs and sampling_params_list are returned + in the py_target_probs attribute of the returned requests. + + Used by test_samples. + """ + # Because max_draft_len can be zero and probs are not computed in this case, + # a separate sampler instance (with larger max_draft_len) is needed to + # compute probs in general. + draft_len_with_probs = max(1, max_draft_len) + sampler_with_probs = self._build_sampler( + use_flashinfer=use_flashinfer, + max_draft_len=draft_len_with_probs, + seq_slot_assignment=seq_slot_assignment, + ) + mock_requests_with_probs = self._build_mock_requests( + sampling_params_list=sampling_params_list, + vocab_size=vocab_size, + seq_slot_assignment=seq_slot_assignment, + # NB: with_draft_logits=True and non-zero draft len ensures that + # LlmRequest.py_target_probs is set. + with_draft_logits=True, + draft_lens=([draft_len_with_probs] * len(sampling_params_list)), + ) + # zero-pad logits to draft_len_with_probs + logits = model_outputs["logits"] + logits_offset = 0 + steps_with_probs = draft_len_with_probs + 1 + logits_with_probs = torch.zeros( + (steps_with_probs * len(mock_requests_with_probs.all_requests()), vocab_size), + dtype=logits.dtype, + device=logits.device, + ) + for req_idx, draft_len in enumerate(draft_lens): + steps = draft_len + 1 + logits_with_probs[ + (req_idx * steps_with_probs) : (req_idx * steps_with_probs + steps) + ] = logits[logits_offset : (logits_offset + steps)] + logits_offset += steps + model_outputs_with_probs = model_outputs.copy() + model_outputs_with_probs["logits"] = logits_with_probs + _ = self._sample( + sampler_with_probs, + scheduled_requests=mock_requests_with_probs, + model_outputs=model_outputs_with_probs, + ) + return mock_requests_with_probs + + @staticmethod + def _inject_batching_check( + patch_ctx: pytest.MonkeyPatch, + *, + sampler: TorchSampler, + use_flashinfer: bool, + ): + """Setup interception of sample_async and request grouping. + + If FlashInfer.sampling is used, this validates that at every + invocation of sample_async, the sampling backend is called at most + once for any given sampling strategy (if FlashInfer.sampling is used). + + Used by test_samples. + """ + # FlashInfer sampling batches requests of the same kind (e.g. top-p) + # together even if they have different parameter values (e.g. probability thresholds). + # This variable tracks which request types have been encountered. 
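+        # It is reset at the start of every intercepted sample_async call (cf. _sample_async).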
+ flashinfer_keys_seen = set() + + if use_flashinfer: + sample_grouped_strategies_orig = sampler._grouped_sampler_cls.sample_grouped_strategies + + def _sample_grouped_strategies( + group_key: FlashInferGroupedStrategySampler.STRATEGY_KEY_TYPE, + strategies: list[Strategy], + logits: torch.Tensor, + *, + group_logit_indices: Optional[torch.Tensor] = None, + generator: Optional[torch.Generator] = None, + return_probs: bool, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + assert issubclass(group_key, sampling_utils_flashinfer._StrategyImpls.StrategyImpl) + assert generator is sampler.get_generator(logits.device) + nonlocal flashinfer_keys_seen + assert group_key not in flashinfer_keys_seen + flashinfer_keys_seen.add(group_key) + return sample_grouped_strategies_orig( + group_key, + strategies, + logits, + group_logit_indices=group_logit_indices, + generator=generator, + return_probs=return_probs, + ) + + patch_ctx.setattr( + sampler._grouped_sampler_cls, + "sample_grouped_strategies", + _sample_grouped_strategies, + ) + + sample_async_orig = sampler.sample_async + + def _sample_async( + scheduled_requests: ScheduledRequests, + model_outputs: dict[str, torch.Tensor], + num_context_logits_prefix_sum: list[int], + resource_manager=None, + ): + nonlocal flashinfer_keys_seen + flashinfer_keys_seen.clear() + res = sample_async_orig( + scheduled_requests, + model_outputs, + num_context_logits_prefix_sum, + resource_manager, + ) + assert flashinfer_keys_seen + return res + + patch_ctx.setattr(sampler, "sample_async", _sample_async) + + @dataclass(frozen=True, kw_only=True) + class _TorchUtilsSamplingParams: + """Variant of UtilsSamplingParams which stores torch.Tensor, to avoid device syncs. + + Used by test_samples. + """ + + temperature: Optional[torch.Tensor] + top_p: Optional[torch.Tensor] + top_k: Optional[torch.Tensor] + + @dataclass(frozen=True, kw_only=True) + class _MockSamplingLogEntry: + probs: torch.Tensor + sampling_params: "TestBatchedSampling._TorchUtilsSamplingParams" + + @staticmethod + def _instrument_sampling_backend( + patch_ctx: pytest.MonkeyPatch, + *, + sampler: TorchSampler, + ) -> list["TestBatchedSampling._MockSamplingLogEntry"]: + """Setup interception of sampling routines. + + This patches the sampling backend. The added instrumentation records observed + sampling parameters and input probs in the returned log. Instead of tokens, the + patched sampling routines return indices into the log, permitting to retrieve the + captured sampling inputs. + + Used by test_samples. 
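+
+        Illustrative flow (hypothetical sizes): if a patched routine receives a probs/logits
+        tensor with three rows, it appends three entries to the log and returns the pseudo-token
+        tensor [N, N + 1, N + 2], where N is the log length before the call; each "token" thus
+        doubles as an index into the returned log.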
+ """ + mock_sampling_log: list[TestBatchedSampling._MockSamplingLogEntry] = [] + + def _mock_flashinfer_top_k_top_p( + logits: torch.Tensor, + *, + top_k: torch.Tensor, + top_p: torch.Tensor, + filter_apply_order: str, + deterministic: bool, + check_nan: bool, + generator: torch.Generator, + ) -> torch.Tensor: + assert filter_apply_order == "top_k_first" + assert deterministic + assert not check_nan, "check_nan syncs" + assert generator is sampler.get_generator(logits.device) + nonlocal mock_sampling_log + new_entries = [ + TestBatchedSampling._MockSamplingLogEntry( + probs=torch.softmax(logits[row_idx], dim=-1), + sampling_params=TestBatchedSampling._TorchUtilsSamplingParams( + top_k=top_k[row_idx], + top_p=top_p[row_idx], + temperature=None, + ), + ) + for row_idx in range(logits.size(0)) + ] + mock_tokens = torch.arange( + len(mock_sampling_log), len(mock_sampling_log) + len(new_entries) + ) + mock_sampling_log += new_entries + return mock_tokens + + patch_ctx.setattr( + flashinfer.sampling, + "top_k_top_p_sampling_from_logits", + _mock_flashinfer_top_k_top_p, + ) + + def _mock_flashinfer_from_logits( + logits: torch.Tensor, + *, + deterministic: bool, + check_nan: bool, + generator: torch.Generator, + ) -> torch.Tensor: + assert deterministic + assert not check_nan, "check_nan syncs" + assert generator is sampler.get_generator(logits.device) + nonlocal mock_sampling_log + new_entries = [ + TestBatchedSampling._MockSamplingLogEntry( + probs=torch.softmax(logits[row_idx], dim=-1), + sampling_params=TestBatchedSampling._TorchUtilsSamplingParams( + top_k=None, + top_p=None, + temperature=None, + ), + ) + for row_idx in range(logits.size(0)) + ] + mock_tokens = torch.arange( + len(mock_sampling_log), len(mock_sampling_log) + len(new_entries) + ) + mock_sampling_log += new_entries + return mock_tokens + + patch_ctx.setattr(flashinfer.sampling, "sampling_from_logits", _mock_flashinfer_from_logits) + + def _mock_flashinfer_top_p( + probs: torch.Tensor, + *, + top_p: torch.Tensor, + deterministic: bool, + check_nan: bool, + generator: torch.Generator, + ) -> torch.Tensor: + assert deterministic + assert not check_nan, "check_nan syncs" + assert generator is sampler.get_generator(probs.device) + nonlocal mock_sampling_log + new_entries = [ + TestBatchedSampling._MockSamplingLogEntry( + probs=probs[row_idx], + sampling_params=TestBatchedSampling._TorchUtilsSamplingParams( + top_k=None, + top_p=top_p[row_idx], + temperature=None, + ), + ) + for row_idx in range(probs.size(0)) + ] + mock_tokens = torch.arange( + len(mock_sampling_log), len(mock_sampling_log) + len(new_entries) + ) + mock_sampling_log += new_entries + return mock_tokens + + patch_ctx.setattr(flashinfer.sampling, "top_p_sampling_from_probs", _mock_flashinfer_top_p) + + def _mock_flashinfer_from_probs( + probs: torch.Tensor, + *, + deterministic: bool, + check_nan: bool, + generator: torch.Generator, + ) -> torch.Tensor: + assert deterministic + assert not check_nan, "check_nan syncs" + assert generator is sampler.get_generator(probs.device) + nonlocal mock_sampling_log + new_entries = [ + TestBatchedSampling._MockSamplingLogEntry( + probs=probs[row_idx], + sampling_params=TestBatchedSampling._TorchUtilsSamplingParams( + top_k=None, + top_p=None, + temperature=None, + ), + ) + for row_idx in range(probs.size(0)) + ] + mock_tokens = torch.arange( + len(mock_sampling_log), len(mock_sampling_log) + len(new_entries) + ) + mock_sampling_log += new_entries + return mock_tokens + + patch_ctx.setattr(flashinfer.sampling, 
"sampling_from_probs", _mock_flashinfer_from_probs) + + def _mock_torch_multinomial( + probs: torch.Tensor, + num_samples: int, + generator: torch.Generator, + ) -> torch.Tensor: + assert generator is sampler.get_generator(probs.device) + assert num_samples == 1 + nonlocal mock_sampling_log + new_entries = [ + TestBatchedSampling._MockSamplingLogEntry( + probs=probs[row_idx], + sampling_params=TestBatchedSampling._TorchUtilsSamplingParams( + top_k=None, + top_p=None, + temperature=None, + ), + ) + for row_idx in range(probs.size(0)) + ] + mock_tokens = torch.arange( + len(mock_sampling_log), len(mock_sampling_log) + len(new_entries) + ) + mock_sampling_log += new_entries + return mock_tokens.unsqueeze(-1) + + patch_ctx.setattr(torch, "multinomial", _mock_torch_multinomial) + + return mock_sampling_log + + @staticmethod + def _validate_intercepted_probs( + log_entry: "TestBatchedSampling._MockSamplingLogEntry", + *, + vocab_size: int, + expected_probs: torch.Tensor, + req_params: UtilsSamplingParams, + ): + """Validate sampling inputs captured by the code injected via _instrument_sampling_backend. + + Used by test_samples. + """ + # Tests rely on UUT handling temperature outside the sampling routines + assert log_entry.sampling_params.temperature is None + + req_has_top_p = ( + log_entry.sampling_params.top_p is not None + and log_entry.sampling_params.top_p.item() != 1 + ) + req_has_top_k = ( + log_entry.sampling_params.top_k is not None + and log_entry.sampling_params.top_k.item() != vocab_size + ) + if req_has_top_k: + assert req_params.top_k is not None + assert log_entry.sampling_params.top_k is not None + assert req_params.top_k == log_entry.sampling_params.top_k.item() + if req_has_top_p: + assert req_params.top_p is not None + assert log_entry.sampling_params.top_p is not None + assert np.allclose(req_params.top_p, log_entry.sampling_params.top_p.item()) + if req_has_top_k or req_has_top_p: + # for top-k and/or top-p _sampling_, probs contains only the top probs, + # whereas log_entry.probs contains all probs passed to the sampling code. + + # validate selection in 'probs' is consistent with log_entry.probs + log_entry_probs_selected = torch.where(expected_probs != 0, log_entry.probs.cpu(), 1) + log_entry_probs_masked = torch.where(expected_probs == 0, log_entry.probs.cpu(), 0) + assert torch.all( + log_entry_probs_masked.amax(dim=-1) <= log_entry_probs_selected.amin(dim=-1) + ) + + # validate non-zero probs + log_entry_probs_selected = torch.where(expected_probs != 0, log_entry.probs.cpu(), 0) + log_entry_probs_selected /= log_entry_probs_selected.sum(-1) + torch.testing.assert_close(log_entry_probs_selected, expected_probs) + else: + torch.testing.assert_close(log_entry.probs.cpu(), expected_probs) + + @staticmethod + def _validate_token_frequencies( + *, + test_token_counts: torch.Tensor, + test_expected_counts: torch.Tensor, + num_samples: int, + ): + """Check consistency of observed and expected token frequencies. + + Used by test_samples. 
+ """ + # NB: G-test yields NaN if expected count is 0 + # -> check those entries separately and mask them + # (https://stats.stackexchange.com/a/668064) + # + test_token_counts_for_zero_prob = torch.where( + test_expected_counts != 0, 0, test_token_counts + ) + assert (test_token_counts_for_zero_prob == 0).all() + test_expected_counts_ma = np.ma.masked_array( + test_expected_counts.numpy(), + mask=(test_expected_counts.numpy() == 0), + ) + test_token_counts_ma = np.ma.masked_array( + test_token_counts.numpy(), + mask=test_expected_counts_ma.mask, + ) + + # FlashInfer normalization is numerically inaccurate enough to + # yield a tiny p-value in the test below, despite passing the + # test's normalization check. Most likely, this mainly + # affects the 'delta' distributions handled explicitly below. + assert np.allclose(test_expected_counts_ma.sum(axis=-1), num_samples) + test_expected_counts_ma /= test_expected_counts_ma.sum(axis=-1, keepdims=True) + test_expected_counts_ma *= num_samples + + # Skip entries with exact agreement. Needed, because + # 'power_divergence' generates NaN p-values otherwise. + mask = ~( + np.round(test_expected_counts_ma).astype(np.int64) + == test_token_counts_ma.astype(np.int64) + ).all(axis=-1) + test_expected_counts_ma = test_expected_counts_ma[mask] + test_token_counts_ma = test_token_counts_ma[mask] + + # Perform G-test (asymptotically approximated by Pearson's chi-square test) to + # check that sampled tokens are consistent with the expected probs. + test_result = power_divergence( + f_obs=test_token_counts_ma, + f_exp=test_expected_counts_ma, + axis=-1, + lambda_="log-likelihood", # = KL divergence + ) + if hasattr(test_result.pvalue, "mask"): + assert test_result.pvalue.mask + pvalue = test_result.pvalue.data + else: + pvalue = test_result.pvalue + if not np.all(pvalue > 0.1): # This can happen by "chance" (many test instances) + # Fail test if sampled data are highly unlikely + assert np.all(pvalue > 0.001) + prob_delta = np.abs(test_token_counts_ma - test_expected_counts_ma) / num_samples + # accept small prob differences + prob_delta = np.where(prob_delta > 5e-2, prob_delta, 0) # NB: this is rather liberal + # bound relative differences on remaining probs + prob_delta_rel = ( + np.ma.masked_array(num_samples * prob_delta, mask=test_expected_counts_ma.mask) + / test_expected_counts_ma.data + ) + assert prob_delta_rel.max() < 0.05 + @pytest.mark.parametrize( ( "use_flashinfer", @@ -1395,293 +1927,62 @@ def test_samples( logits / probs and return a pseudo-token identifying the capture result. Thus, the corresponding observed PMFs can be directly compared with the expected ones. """ - with torch.cuda.Stream(): + + @contextmanager + def _uut_provider(is_warmup: bool) -> Generator[Callable[[], None], None, None]: torch.manual_seed(42) # torch.testing.make_tensor does not accept Generator - # Because max_draft_len can be zero and probs are not computed in this case, - # a separate sampler instance (with larger max_draft_len) is needed to - # compute probs in general. - draft_len_with_probs = max(3, max_draft_len) - sampler_with_probs = self._build_sampler( + # Compute sampling probabilities for the given sampling_params_list and + # model_outputs. These probs, the computation of which is validated by 'test_probs', + # are used to validate the batched sampling process later in this test. 
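+            # The reference probs are attached to the returned requests as py_target_probs.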
+ mock_requests_with_probs = self._compute_probs( use_flashinfer=use_flashinfer, - max_draft_len=draft_len_with_probs, - seq_slot_assignment=seq_slot_assignment, - ) - mock_requests_with_probs = self._build_mock_requests( + model_outputs=model_outputs, sampling_params_list=sampling_params_list, - vocab_size=vocab_size, seq_slot_assignment=seq_slot_assignment, - # NB: with_draft_logits=True and non-zero draft len ensures that - # LlmRequest.py_target_probs is set. - with_draft_logits=True, - draft_lens=([draft_len_with_probs] * len(sampling_params_list)), - ) - # zero-pad logits to draft_len_with_probs - logits = model_outputs["logits"] - logits_offset = 0 - steps_with_probs = draft_len_with_probs + 1 - logits_with_probs = torch.zeros( - (steps_with_probs * len(mock_requests_with_probs.all_requests()), vocab_size), - dtype=logits.dtype, - device=logits.device, - ) - for req_idx, draft_len in enumerate(draft_lens): - steps = draft_len + 1 - logits_with_probs[ - (req_idx * steps_with_probs) : (req_idx * steps_with_probs + steps) - ] = logits[logits_offset : (logits_offset + steps)] - logits_offset += steps - model_outputs_with_probs = model_outputs.copy() - model_outputs_with_probs["logits"] = logits_with_probs - _ = self._sample( - sampler_with_probs, - scheduled_requests=mock_requests_with_probs, - model_outputs=model_outputs_with_probs, + vocab_size=vocab_size, + max_draft_len=max_draft_len, + draft_lens=draft_lens, ) - num_samples = 5000 if not bypass_sampling else 1 - - @dataclass(frozen=True, kw_only=True) - class MockSamplingLogEntry: - probs: torch.Tensor - sampling_params: UtilsSamplingParams + num_samples = 5000 if not (bypass_sampling or is_warmup) else 1 # filled when bypass_sampling=True - mock_sampling_log: list[MockSamplingLogEntry] = [] + mock_sampling_log: Optional[list[TestBatchedSampling._MockSamplingLogEntry]] = None with monkeypatch.context() as patch_ctx: - # FlashInfer sampling batches requests of the same kind (e.g. top-p) - # together even if they have different parameter values (e.g. probability thresholds). - # This variable tracks which request types have been encountered. 
- flashinfer_keys_seen = set() - - if use_flashinfer: - sample_grouped_strategies_orig = ( - sampler._grouped_sampler_cls.sample_grouped_strategies - ) - - def _sample_grouped_strategies( - group_key: FlashInferGroupedStrategySampler.STRATEGY_KEY_TYPE, - strategies: list[Strategy], - logits: torch.Tensor, - *, - group_logit_indices: Optional[torch.Tensor] = None, - generator: Optional[torch.Generator] = None, - return_probs: bool, - ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - assert issubclass( - group_key, sampling_utils_flashinfer._StrategyImpls.StrategyImpl - ) - assert generator is sampler.get_generator(logits.device) - nonlocal flashinfer_keys_seen - assert group_key not in flashinfer_keys_seen - flashinfer_keys_seen.add(group_key) - return sample_grouped_strategies_orig( - group_key, - strategies, - logits, - group_logit_indices=group_logit_indices, - generator=generator, - return_probs=return_probs, - ) - - patch_ctx.setattr( - sampler._grouped_sampler_cls, - "sample_grouped_strategies", - _sample_grouped_strategies, - ) - - sample_async_orig = sampler.sample_async - - def _sample_async( - scheduled_requests: ScheduledRequests, - model_outputs: dict[str, torch.Tensor], - num_context_logits_prefix_sum: list[int], - resource_manager=None, - ): - nonlocal flashinfer_keys_seen - flashinfer_keys_seen.clear() - res = sample_async_orig( - scheduled_requests, - model_outputs, - num_context_logits_prefix_sum, - resource_manager, - ) - assert flashinfer_keys_seen - return res - - patch_ctx.setattr(sampler, "sample_async", _sample_async) - + self._inject_batching_check( + patch_ctx, sampler=sampler, use_flashinfer=use_flashinfer + ) if bypass_sampling: - - def _mock_flashinfer_top_k_top_p( - logits: torch.Tensor, - *, - top_k: torch.Tensor, - top_p: torch.Tensor, - filter_apply_order: str, - deterministic: bool, - check_nan: bool, - generator: torch.Generator, - ) -> torch.Tensor: - assert filter_apply_order == "top_k_first" - assert deterministic - assert not check_nan, "check_nan syncs" - assert generator is sampler.get_generator(logits.device) - nonlocal mock_sampling_log - new_entries = [ - MockSamplingLogEntry( - probs=torch.softmax(logits[row_idx], dim=-1), - sampling_params=UtilsSamplingParams( - top_k=cast(int, top_k[row_idx].item()), - top_p=top_p[row_idx].item(), - temperature=None, - ), - ) - for row_idx in range(logits.size(0)) - ] - mock_tokens = torch.arange( - len(mock_sampling_log), len(mock_sampling_log) + len(new_entries) - ) - mock_sampling_log += new_entries - return mock_tokens - - patch_ctx.setattr( - flashinfer.sampling, - "top_k_top_p_sampling_from_logits", - _mock_flashinfer_top_k_top_p, + mock_sampling_log = self._instrument_sampling_backend( + patch_ctx, sampler=sampler ) - def _mock_flashinfer_from_logits( - logits: torch.Tensor, - *, - deterministic: bool, - check_nan: bool, - generator: torch.Generator, - ) -> torch.Tensor: - assert deterministic - assert not check_nan, "check_nan syncs" - assert generator is sampler.get_generator(logits.device) - nonlocal mock_sampling_log - new_entries = [ - MockSamplingLogEntry( - probs=torch.softmax(logits[row_idx], dim=-1), - sampling_params=UtilsSamplingParams( - top_k=None, - top_p=None, - temperature=None, - ), - ) - for row_idx in range(logits.size(0)) - ] - mock_tokens = torch.arange( - len(mock_sampling_log), len(mock_sampling_log) + len(new_entries) - ) - mock_sampling_log += new_entries - return mock_tokens + @dataclass + class UutResult: + new_tokens_repeats: torch.Tensor - patch_ctx.setattr( - 
flashinfer.sampling, "sampling_from_logits", _mock_flashinfer_from_logits - ) + @dataclass + class UutResultWrapper: + result: Optional[UutResult] = None - def _mock_flashinfer_top_p( - probs: torch.Tensor, - *, - top_p: torch.Tensor, - deterministic: bool, - check_nan: bool, - generator: torch.Generator, - ) -> torch.Tensor: - assert deterministic - assert not check_nan, "check_nan syncs" - assert generator is sampler.get_generator(probs.device) - nonlocal mock_sampling_log - new_entries = [ - MockSamplingLogEntry( - probs=probs[row_idx], - sampling_params=UtilsSamplingParams( - top_k=None, - top_p=top_p[row_idx].item(), - temperature=None, - ), - ) - for row_idx in range(probs.size(0)) - ] - mock_tokens = torch.arange( - len(mock_sampling_log), len(mock_sampling_log) + len(new_entries) - ) - mock_sampling_log += new_entries - return mock_tokens + res = UutResultWrapper() - patch_ctx.setattr( - flashinfer.sampling, "top_p_sampling_from_probs", _mock_flashinfer_top_p + def _uut(res=res): + new_tokens_repeats = self._sample( + sampler, + scheduled_requests=mock_requests, + model_outputs=model_outputs, + num_repeats=num_samples, + allow_sync=is_warmup, ) + res.result = UutResult(new_tokens_repeats=new_tokens_repeats) - def _mock_flashinfer_from_probs( - probs: torch.Tensor, - *, - deterministic: bool, - check_nan: bool, - generator: torch.Generator, - ) -> torch.Tensor: - assert deterministic - assert not check_nan, "check_nan syncs" - assert generator is sampler.get_generator(probs.device) - nonlocal mock_sampling_log - new_entries = [ - MockSamplingLogEntry( - probs=probs[row_idx], - sampling_params=UtilsSamplingParams( - top_k=None, - top_p=None, - temperature=None, - ), - ) - for row_idx in range(probs.size(0)) - ] - mock_tokens = torch.arange( - len(mock_sampling_log), len(mock_sampling_log) + len(new_entries) - ) - mock_sampling_log += new_entries - return mock_tokens + yield _uut - patch_ctx.setattr( - flashinfer.sampling, "sampling_from_probs", _mock_flashinfer_from_probs - ) - - def _mock_torch_multinomial( - probs: torch.Tensor, - num_samples: int, - generator: torch.Generator, - ) -> torch.Tensor: - assert generator is sampler.get_generator(probs.device) - assert num_samples == 1 - nonlocal mock_sampling_log - new_entries = [ - MockSamplingLogEntry( - probs=probs[row_idx], - sampling_params=UtilsSamplingParams( - top_k=None, - top_p=None, - temperature=None, - ), - ) - for row_idx in range(probs.size(0)) - ] - mock_tokens = torch.arange( - len(mock_sampling_log), len(mock_sampling_log) + len(new_entries) - ) - mock_sampling_log += new_entries - return mock_tokens.unsqueeze(-1) - - patch_ctx.setattr(torch, "multinomial", _mock_torch_multinomial) - - new_tokens_repeats = self._sample( - sampler, - scheduled_requests=mock_requests, - model_outputs=model_outputs, - num_repeats=num_samples, - ) + assert res.result is not None + new_tokens_repeats = res.result.new_tokens_repeats # remove 'beam' dimension assert new_tokens_repeats.size(-2) == 1 @@ -1710,6 +2011,7 @@ def _mock_torch_multinomial( token_counts = token_counts.to(dtype=torch.float32) assert (token_counts.sum(-1, keepdim=True) == num_samples).all() + logits = model_outputs["logits"] for req_idx, (req, req_with_probs, draft_len) in enumerate( zip( mock_requests.all_requests(), @@ -1746,124 +2048,28 @@ def _mock_torch_multinomial( torch.testing.assert_close(req.py_target_probs.cpu(), probs) if bypass_sampling: # fast path (mock sampling) + assert mock_sampling_log is not None for step_idx in range(draft_len + 1): log_idx = 
new_tokens_repeats[step_idx, req.py_seq_slot, 0] log_entry = mock_sampling_log[log_idx] req_params = _request_get_sampling_params(req) - - # Tests rely on UUT handling temperature outside the sampling routines - assert log_entry.sampling_params.temperature is None - - req_has_top_p = ( - log_entry.sampling_params.top_p is not None - and log_entry.sampling_params.top_p != 1 - ) - req_has_top_k = ( - log_entry.sampling_params.top_k is not None - and log_entry.sampling_params.top_k != vocab_size + expected_probs = probs[step_idx] + self._validate_intercepted_probs( + log_entry, + vocab_size=vocab_size, + expected_probs=expected_probs, + req_params=req_params, ) - if req_has_top_k: - assert req_params.top_k is not None - assert req_params.top_k == log_entry.sampling_params.top_k - if req_has_top_p: - assert req_params.top_p is not None - assert log_entry.sampling_params.top_p is not None - assert np.allclose(req_params.top_p, log_entry.sampling_params.top_p) - if req_has_top_k or req_has_top_p: - # for top-k and/or top-p _sampling_, probs contains only the top probs, - # whereas log_entry.probs contains all probs passed to the sampling code. - - # validate selection in 'probs' is consistent with log_entry.probs - log_entry_probs_selected = torch.where( - probs[step_idx] != 0, log_entry.probs.cpu(), 1 - ) - log_entry_probs_masked = torch.where( - probs[step_idx] == 0, log_entry.probs.cpu(), 0 - ) - assert torch.all( - log_entry_probs_masked.amax(dim=-1) - <= log_entry_probs_selected.amin(dim=-1) - ) - - # validate non-zero probs - log_entry_probs_selected = torch.where( - probs[step_idx] != 0, log_entry.probs.cpu(), 0 - ) - log_entry_probs_selected /= log_entry_probs_selected.sum(-1) - torch.testing.assert_close(log_entry_probs_selected, probs[step_idx]) - else: - torch.testing.assert_close(log_entry.probs.cpu(), probs[step_idx]) - - continue # no real samples to run tests below - - test_token_counts = token_counts[: (draft_len + 1), req.py_seq_slot] - test_expected_counts = num_samples * probs.cpu() - - # NB: G-test yields NaN if expected count is 0 - # -> check those entries separately and mask them - # (https://stats.stackexchange.com/a/668064) - # - test_token_counts_for_zero_prob = torch.where( - test_expected_counts != 0, 0, test_token_counts - ) - assert (test_token_counts_for_zero_prob == 0).all() - test_expected_counts_ma = np.ma.masked_array( - test_expected_counts.numpy(), - mask=(test_expected_counts.numpy() == 0), - ) - test_token_counts_ma = np.ma.masked_array( - test_token_counts.numpy(), - mask=test_expected_counts_ma.mask, - ) - - # FlashInfer normalization is numerically inaccurate enough to - # yield a tiny p-value in the test below, despite passing the - # test's normalization check. Most likely, this mainly - # affects the 'delta' distributions handled explicitly below. - assert np.allclose(test_expected_counts_ma.sum(axis=-1), num_samples) - test_expected_counts_ma /= test_expected_counts_ma.sum(axis=-1, keepdims=True) - test_expected_counts_ma *= num_samples - - # Skip entries with exact agreement. Needed, because - # 'power_divergence' generates NaN p-values otherwise. - mask = ~( - np.round(test_expected_counts_ma).astype(np.int64) - == test_token_counts_ma.astype(np.int64) - ).all(axis=-1) - test_expected_counts_ma = test_expected_counts_ma[mask] - test_token_counts_ma = test_token_counts_ma[mask] - - # Perform G-test (asymptotically approximated by Pearson's chi-square test) to - # check that sampled tokens are consistent with the expected probs. 
- test_result = power_divergence( - f_obs=test_token_counts_ma, - f_exp=test_expected_counts_ma, - axis=-1, - lambda_="log-likelihood", # = KL divergence - ) - if hasattr(test_result.pvalue, "mask"): - assert test_result.pvalue.mask - pvalue = test_result.pvalue.data else: - pvalue = test_result.pvalue - if not np.all(pvalue > 0.1): # This can happen by "chance" (many test instances) - # Fail test if sampled data are highly unlikely - assert np.all(pvalue > 0.001) - prob_delta = ( - np.abs(test_token_counts_ma - test_expected_counts_ma) / num_samples - ) - # accept small prob differences - prob_delta = np.where( - prob_delta > 5e-2, prob_delta, 0 - ) # NB: this is rather liberal - # bound relative differences on remaining probs - prob_delta_rel = ( - np.ma.masked_array( - num_samples * prob_delta, mask=test_expected_counts_ma.mask - ) - / test_expected_counts_ma.data + test_token_counts = token_counts[: (draft_len + 1), req.py_seq_slot] + test_expected_counts = num_samples * probs.cpu() + self._validate_token_frequencies( + test_token_counts=test_token_counts, + test_expected_counts=test_expected_counts, + num_samples=num_samples, ) - assert prob_delta_rel.max() < 0.05 + + _run_test_with_warmup(_uut_provider) @staticmethod def _build_seq_slot_assignments() -> list[tuple[list[int], int, str]]: @@ -1961,7 +2167,9 @@ def test_unbatch_sampling_results( validates that the sampling results are copied into the correct locations in the output buffers. """ - with torch.cuda.Stream(): + + @contextmanager + def _uut_provider(is_warmup: bool) -> Generator[Callable[[], None], None, None]: seq_slots, total_seq_slots = seq_slot_assignment seq_slots_tensor = torch.tensor(seq_slots, dtype=torch.int32) @@ -2005,15 +2213,30 @@ def test_unbatch_sampling_results( ) seq_slots_tensor_snapshot = seq_slots_tensor.clone() - # FIXME: Currently, _unbatch_sampling_results is not fully async (TRTLLM-9175) - # with assert_no_cuda_sync(sync_timeout_s=0.2): - new_tokens_host = sampler._unbatch_sampling_results( - batched_sampling_result=batched_sampling_result, - new_tokens_cuda=new_tokens_cuda, - req_num_steps=req_num_steps, - seq_slots=seq_slots_tensor, - ) + @dataclass + class UutResult: + new_tokens_host: torch.Tensor + + @dataclass + class UutResultWrapper: + result: Optional[UutResult] = None + + res = UutResultWrapper() + + def _uut(res=res): + new_tokens_host = sampler._unbatch_sampling_results( + batched_sampling_result=batched_sampling_result, + new_tokens_cuda=new_tokens_cuda, + req_num_steps=req_num_steps, + seq_slots=seq_slots_tensor, + ) + res.result = UutResult(new_tokens_host=new_tokens_host) + + yield _uut + torch.cuda.synchronize() + assert res.result is not None + new_tokens_host = res.result.new_tokens_host assert new_tokens_host.device == torch.device("cpu") # check for unwanted side effects @@ -2038,3 +2261,5 @@ def test_unbatch_sampling_results( torch.testing.assert_close(new_tokens_cuda[:steps, seq_slot, 0], req_tokens) torch.testing.assert_close(new_tokens_host[:steps, seq_slot, 0], req_tokens.cpu()) input_offset += steps + + _run_test_with_warmup(_uut_provider, max_sync_s=0.2) diff --git a/tests/unittest/utils/util.py b/tests/unittest/utils/util.py index eda687f70d3..a8475927f95 100644 --- a/tests/unittest/utils/util.py +++ b/tests/unittest/utils/util.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import faulthandler import math import os import time @@ -472,7 +473,7 @@ def check_accuracy(a, b, atol, rtol, percent): mpi_disabled(), reason="This test is skipped for Ray orchestrator.") -@dataclass +@dataclass(kw_only=True) class DeviceSleepCtl: _cancellation_requested: bool = False @@ -485,10 +486,12 @@ def cancel(self): @hostfunc -def device_sleep(duration_s: float, - *, - ctl: DeviceSleepCtl, - spin_s: float = 0.1): +def device_sleep( + duration_s: float, + *, + ctl: DeviceSleepCtl, + spin_s: float = 0.1, +): spin_iters = math.ceil(duration_s / spin_s) for _ in range(spin_iters): if ctl.cancellation_requested: @@ -498,21 +501,37 @@ def device_sleep(duration_s: float, @contextmanager def assert_no_cuda_sync( - sync_timeout_s: float = 5) -> Generator[None, None, None]: + sync_timeout_s: float = 5, ) -> Generator[None, None, None]: """Check that the function does not stream synchronize.""" + # NB: This implementation only assumes that the CUDA operations performed + # in the guarded scope use the currently selected CUDA stream. This + # should also cover custom Torch ops as well as non-Torch kernels. + # + # Python's faulthandler is used to provide tracebacks which can help + # pin-pointing synchronizing code. This is combined with PyTorch's + # less general (instrumentation based) tooling, which is expected + # to provide improved error reporting for those issues which it can + # detect. sleep_finished_event = torch.cuda.Event() scope_finished_event = torch.cuda.Event() torch.cuda.synchronize() sleep_ctl = DeviceSleepCtl() - device_sleep(sync_timeout_s, ctl=sleep_ctl) + faulthandler.dump_traceback_later(sync_timeout_s) + device_sleep(2 * sync_timeout_s, ctl=sleep_ctl) # cancelled below sleep_finished_event.record() - yield None + torch_debug_mode_orig = torch.cuda.get_sync_debug_mode() + torch.cuda.set_sync_debug_mode("error") + try: + yield None + finally: + torch.cuda.set_sync_debug_mode(torch_debug_mode_orig) scope_finished_event.record() assert not sleep_finished_event.query( ), """sync code should return quickly""" + faulthandler.cancel_dump_traceback_later() sleep_ctl.cancel() scope_finished_event.synchronize()