@@ -23,8 +23,8 @@
 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.cuda_graph import CUDAGraphWrapper
 from vllm.compilation.monitor import set_cudagraph_capturing_enabled
+from vllm.compilation.ubatch_utils import UbatchSlice, UBatchSlices
 from vllm.compilation.ubatch_wrapper import UBatchWrapper
-from vllm.compilation.ubatch_utils import (UbatchSlice, UBatchSlices)
 from vllm.config import (CompilationLevel, CUDAGraphMode, VllmConfig,
                          get_layers_from_vllm_config, update_config)
 from vllm.distributed.eplb.eplb_state import EplbState
@@ -60,7 +60,7 @@
 from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
 from vllm.v1.attention.backends.utils import (
     AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata,
-    UbatchSlice, create_fast_prefill_custom_backend,
+    create_fast_prefill_custom_backend,
     reorder_batch_to_split_decodes_and_prefills, split_attn_metadata)
 from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
 from vllm.v1.kv_cache_interface import (AttentionSpec,
@@ -605,8 +605,8 @@ def _ubatch_split(
         num_pad_tokens = 0
         num_tokens_after_padding = None
         (should_ubatch, num_pad_tokens,
-         num_tokens_after_padding) = self.get_dp_padding_ubatch(total_num_scheduled_tokens,
-                                                                should_attempt_ubatching)
+         num_tokens_after_padding) = self.get_dp_padding_ubatch(
+             total_num_scheduled_tokens, should_attempt_ubatching)
         if not should_ubatch:
             return (None, 0, None)

@@ -1570,16 +1570,16 @@ def get_dp_padding_ubatch(
         should_ubatch = False

         # Note that we compute the number of padded tokens per ubatch
-        (should_ubatch,
-         num_tokens_across_dp) = self.should_ubatch_with_num_tokens(should_ubatch,
-            num_tokens_unpadded // 2, num_tokens_per_ubatch)
+        (should_ubatch,
+         num_tokens_across_dp) = self.should_ubatch_with_num_tokens(
+             should_ubatch, num_tokens_unpadded // 2, num_tokens_per_ubatch)
         if not should_ubatch:
             assert num_tokens_across_dp is None
             return should_ubatch, 0, num_tokens_across_dp

         assert num_tokens_across_dp is not None

-        max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp).item()
+        max_tokens_across_dp_cpu = int(torch.max(num_tokens_across_dp).item())
         num_tokens_after_padding = torch.tensor([max_tokens_across_dp_cpu] *
                                                 dp_size,
                                                 device="cpu",
@@ -1594,22 +1594,23 @@ def get_dp_padding_ubatch(
     # the second ubatch slice out to the total number of tokens
     # (num_tokens + padding)
     def pad_out_ubatch_slice(self, ubatch_slices: UBatchSlices,
-                                   num_total_tokens: int):
+                             num_total_tokens: int):
         padded_second_ubatch_slice = slice(ubatch_slices[1].token_slice.start,
                                            num_total_tokens)
         ubatch_slices[1] = UbatchSlice(padded_second_ubatch_slice,
                                        padded_second_ubatch_slice)

-    def should_ubatch_with_num_tokens(self, should_ubatch: bool, orig_num_tokens_per_ubatch: int,
-                                      padded_num_tokens_per_ubatch: int,
-                                      ) -> tuple[bool, Optional[torch.Tensor]]:
+    def should_ubatch_with_num_tokens(
+        self,
+        should_ubatch: bool,
+        orig_num_tokens_per_ubatch: int,
+        padded_num_tokens_per_ubatch: int,
+    ) -> tuple[bool, Optional[torch.Tensor]]:
         dp_size = self.vllm_config.parallel_config.data_parallel_size
         dp_rank = self.vllm_config.parallel_config.data_parallel_rank
-        return DPMetadata.should_ubatch_across_dp(should_ubatch,
-                                                  orig_num_tokens_per_ubatch,
-                                                  padded_num_tokens_per_ubatch,
-                                                  dp_size,
-                                                  dp_rank)
+        return DPMetadata.should_ubatch_across_dp(
+            should_ubatch, orig_num_tokens_per_ubatch,
+            padded_num_tokens_per_ubatch, dp_size, dp_rank)

     def _pool(
         self,
@@ -2426,23 +2427,26 @@ def _dummy_run(
             remove_lora: If False, dummy LoRAs are not destroyed after the run
         """
         ubatch_enabled = self.parallel_config.enable_microbatching
+        num_tokens_across_dp = None
+        num_pad = 0
         should_ubatch = False
         if ubatch_enabled:
             should_ubatch = num_tokens >= \
                 self.parallel_config.microbatching_token_threshold and \
                 allow_microbatching
-            should_ubatch, _ = self.should_ubatch_with_num_tokens(
-                should_ubatch,
-                num_tokens // 2,
-                num_tokens // 2,
-            )
+
+            (should_ubatch, num_pad,
+             num_tokens_across_dp) = self.get_dp_padding_ubatch(
+                 num_tokens, should_ubatch)
+
+            # Currently the dummy run should only be ubatching during
+            # cuda graph capture, meaning all DP ranks should already
+            # have the same batch size
+            assert num_pad == 0
         assert cudagraph_runtime_mode in {
             CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL
         }

-        # Padding for DP
-        num_tokens_across_dp = None
-        num_pad = 0
         if not should_ubatch:
             num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens)
             num_tokens += num_pad
@@ -2497,12 +2501,12 @@ def _dummy_run(
             # We only support decode-only cudagraphs
             assert num_reqs == num_tokens
             assert num_tokens % 2 == 0
-            num_tokens_per_ubatch = num_tokens // 2
-            dp_size = self.vllm_config.parallel_config.data_parallel_size
-            num_tokens_across_dp = torch.tensor([num_tokens_per_ubatch] *
-                                                dp_size,
-                                                device="cpu",
-                                                dtype=torch.int32)
+            # num_tokens_per_ubatch = num_tokens // 2
+            # dp_size = self.vllm_config.parallel_config.data_parallel_size
+            # num_tokens_across_dp = torch.tensor([num_tokens_per_ubatch] *
+            #                                     dp_size,
+            #                                     device="cpu",
+            #                                     dtype=torch.int32)
             ubatch_slices = [
                 UbatchSlice(slice(0, num_reqs // 2), slice(0,
                                                            num_tokens // 2)),
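
For readers unfamiliar with the microbatching path, below is a minimal sketch of the collective decision this patch routes through get_dp_padding_ubatch: each DP rank reports its per-microbatch token count, the group only microbatches if every rank agrees, and each microbatch is padded up to the largest count across ranks. The function name decide_ubatch_padding and its list-based inputs are illustrative stand-ins, not vLLM's API; in the patch the exchange happens via DPMetadata.should_ubatch_across_dp.

# Illustrative sketch only (not part of the patch): mirrors the shape of the
# padding decision that get_dp_padding_ubatch() delegates to the DP group.
from typing import Optional

import torch


def decide_ubatch_padding(
    dp_rank: int,
    tokens_per_ubatch_across_dp: list[int],
    should_ubatch_per_rank: list[bool],
) -> tuple[bool, int, Optional[torch.Tensor]]:
    """Return (should_ubatch, num_pad_tokens for this rank, padded counts)."""
    # Microbatching is all-or-nothing across the DP group: one rank opting
    # out disables it for everyone, matching the early return in the patch.
    if not all(should_ubatch_per_rank):
        return False, 0, None

    # Pad every rank's microbatch up to the largest one, analogous to the
    # int(torch.max(...).item()) computation added in get_dp_padding_ubatch.
    counts = torch.tensor(tokens_per_ubatch_across_dp, dtype=torch.int32)
    max_tokens = int(torch.max(counts).item())
    num_tokens_after_padding = torch.full_like(counts, max_tokens)

    # This rank's padding is its gap to the maximum; the dummy-run path in
    # the patch asserts this gap is zero during cuda graph capture.
    num_pad = max_tokens - tokens_per_ubatch_across_dp[dp_rank]
    return True, num_pad, num_tokens_after_padding


# Example: rank 0 has 96 tokens per microbatch, rank 1 has 128, both agree
# to microbatch, so rank 0 pads each of its microbatches by 32 tokens.
print(decide_ubatch_padding(0, [96, 128], [True, True]))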