@@ -1996,7 +1996,8 @@ def eplb_step(self, is_dummy: bool = False, is_profile: bool = False) -> None:
     # Should be called after attention metadata creation. This just pads
     # the second ubatch slice out to the total number of tokens
     # (num_tokens + padding)
-    def pad_out_ubatch_slice(self, ubatch_slices: UBatchSlices, num_total_tokens: int):
+    @staticmethod
+    def pad_out_ubatch_slice(ubatch_slices: UBatchSlices, num_total_tokens: int):
         padded_second_ubatch_slice = slice(
             ubatch_slices[1].token_slice.start, num_total_tokens
         )
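
Aside from the `@staticmethod` conversion, this hunk leaves the padding logic untouched: the second micro-batch's token slice is stretched so it ends at the padded token count. Below is a minimal, self-contained sketch of that behavior; `UBatchSlice` here is a hypothetical stand-in for the real vLLM type (which carries more state than a single `token_slice`), and the final assignment is inferred, since the diff only shows the `slice(...)` construction.

```python
# Sketch only: `UBatchSlice` is a hypothetical stand-in for vLLM's type.
from dataclasses import dataclass


@dataclass
class UBatchSlice:
    token_slice: slice


def pad_out_ubatch_slice(ubatch_slices: list[UBatchSlice], num_total_tokens: int) -> None:
    # Stretch the second micro-batch so it ends at num_total_tokens
    # (real tokens + padding) instead of at the last scheduled token.
    second = ubatch_slices[1]
    second.token_slice = slice(second.token_slice.start, num_total_tokens)


ubatches = [UBatchSlice(slice(0, 96)), UBatchSlice(slice(96, 160))]
pad_out_ubatch_slice(ubatches, num_total_tokens=192)
assert ubatches[1].token_slice == slice(96, 192)  # padded out to 192 tokens
```

Dropping `self` is consistent with the body: the method reads nothing from the instance, so `@staticmethod` is the natural signature.
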
@@ -2085,12 +2086,13 @@ def _preprocess(
         dict[str, Any],
     ]:
         num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
+        is_first_rank = get_pp_group().is_first_rank
 
         # _prepare_inputs may reorder the batch, so we must gather multi
         # modal outputs after that to ensure the correct order
         if (
             self.supports_mm_inputs
-            and get_pp_group().is_first_rank
+            and is_first_rank
             and not self.model_config.is_encoder_decoder
         ):
             # Run the multimodal encoder if any.
@@ -2115,7 +2117,7 @@ def _preprocess(
                 **self._init_model_kwargs(num_scheduled_tokens),
                 **self._extract_mm_kwargs(scheduler_output),
             }
-        elif self.enable_prompt_embeds and get_pp_group().is_first_rank:
+        elif self.enable_prompt_embeds and is_first_rank:
             # Get the input embeddings for the tokens that are not input embeds,
             # then put them into the appropriate positions.
             # TODO(qthequartermasterman): Since even when prompt embeds are
@@ -2155,7 +2157,7 @@ def _preprocess(
         else:
             positions = self.positions.gpu[:num_input_tokens]
 
-        if get_pp_group().is_first_rank:
+        if is_first_rank:
             intermediate_tensors = None
         else:
             intermediate_tensors = self.sync_and_slice_intermediate_tensors(
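
The `_preprocess` hunks are a single micro-refactor: the pipeline-parallel rank check is evaluated once at the top and reused in all three branches instead of calling `get_pp_group().is_first_rank` each time. A self-contained sketch of the pattern, with a hypothetical `PPGroup` class standing in for the object that `get_pp_group()` actually returns:

```python
from dataclasses import dataclass


@dataclass
class PPGroup:
    """Hypothetical stand-in for the object returned by get_pp_group()."""

    rank: int

    @property
    def is_first_rank(self) -> bool:
        return self.rank == 0


def preprocess(pp: PPGroup, supports_mm_inputs: bool, enable_prompt_embeds: bool) -> str:
    # Evaluate the rank check once; every branch below reads the local.
    is_first_rank = pp.is_first_rank

    if supports_mm_inputs and is_first_rank:
        return "run multimodal encoder"
    if enable_prompt_embeds and is_first_rank:
        return "scatter prompt embeddings"
    if is_first_rank:
        return "start from token ids"
    # Non-first ranks consume intermediate tensors from the previous stage.
    return "receive intermediate tensors"


assert preprocess(PPGroup(rank=0), True, False) == "run multimodal encoder"
assert preprocess(PPGroup(rank=1), True, False) == "receive intermediate tensors"
```
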
@@ -2186,38 +2188,37 @@ def _sample(
         # Sample the next token and get logprobs if needed.
         sampling_metadata = self.input_batch.sampling_metadata
         if spec_decode_metadata is None:
-            sampler_output = self.sampler(
+            return self.sampler(
                 logits=logits,
                 sampling_metadata=sampling_metadata,
             )
-        else:
-            # When indexing with a tensor (bonus_logits_indices), PyTorch
-            # creates a new tensor with separate storage from the original
-            # logits tensor. This means any in-place operations on bonus_logits
-            # won't affect the original logits tensor.
-            assert logits is not None
-            bonus_logits = logits[spec_decode_metadata.bonus_logits_indices]
-            sampler_output = self.sampler(
-                logits=bonus_logits,
-                sampling_metadata=sampling_metadata,
-                predict_bonus_token=True,
-            )
-            bonus_token_ids = sampler_output.sampled_token_ids
-
-            # Just like `bonus_logits`, `target_logits` is a new tensor with
-            # separate storage from the original `logits` tensor. Therefore,
-            # it is safe to update `target_logits` in place.
-            target_logits = logits[spec_decode_metadata.target_logits_indices]
-            output_token_ids = self.rejection_sampler(
-                spec_decode_metadata,
-                None,  # draft_probs
-                target_logits,
-                bonus_token_ids,
-                sampling_metadata,
-            )
-            sampler_output.sampled_token_ids = output_token_ids
-            self._update_states_after_model_execute(output_token_ids)
 
+        # When indexing with a tensor (bonus_logits_indices), PyTorch
+        # creates a new tensor with separate storage from the original
+        # logits tensor. This means any in-place operations on bonus_logits
+        # won't affect the original logits tensor.
+        assert logits is not None
+        bonus_logits = logits[spec_decode_metadata.bonus_logits_indices]
+        sampler_output = self.sampler(
+            logits=bonus_logits,
+            sampling_metadata=sampling_metadata,
+            predict_bonus_token=True,
+        )
+        bonus_token_ids = sampler_output.sampled_token_ids
+
+        # Just like `bonus_logits`, `target_logits` is a new tensor with
+        # separate storage from the original `logits` tensor. Therefore,
+        # it is safe to update `target_logits` in place.
+        target_logits = logits[spec_decode_metadata.target_logits_indices]
+        output_token_ids = self.rejection_sampler(
+            spec_decode_metadata,
+            None,  # draft_probs
+            target_logits,
+            bonus_token_ids,
+            sampling_metadata,
+        )
+        sampler_output.sampled_token_ids = output_token_ids
+        self._update_states_after_model_execute(output_token_ids)
         return sampler_output
 
     def _bookkeeping_sync(
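
The restructured `_sample` turns the no-spec-decode path into an early return and de-indents the speculative path; behavior is unchanged. The comments it keeps lean on a PyTorch semantics detail that is easy to verify in isolation: indexing with an index tensor (advanced indexing) allocates a copy, while basic slicing returns a view. A small demonstration with plain `torch`, no vLLM types:

```python
import torch

logits = torch.zeros(4, 8)
indices = torch.tensor([1, 3])

# Advanced indexing: `selected` gets its own storage, so in-place ops
# on it cannot corrupt the original logits tensor.
selected = logits[indices]
selected += 100.0
assert logits.sum() == 0  # original untouched
assert selected.data_ptr() != logits.data_ptr()  # separate storage

# Contrast: basic slicing returns a view that shares storage.
view = logits[1:3]
view += 1.0
assert logits.sum() == 16  # original modified through the view
```

This is why both `bonus_logits` and `target_logits` can be handed to samplers that mutate them in place: each is a fresh copy of the rows it selected from `logits`, not a view into it.
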
@@ -3741,7 +3742,7 @@ def freeze_gc():
         decode_cudagraph_batch_sizes = [
             x
             for x in self.cudagraph_batch_sizes
-            if x <= max_num_tokens and x >= self.uniform_decode_query_len
+            if max_num_tokens >= x >= self.uniform_decode_query_len
         ]
         compilation_cases_decode = list(reversed(decode_cudagraph_batch_sizes))
         self._capture_cudagraphs(
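
The filter rewrite is purely cosmetic: Python chains comparisons, so `max_num_tokens >= x >= self.uniform_decode_query_len` is the same range check as the two-clause `and`, with `x` evaluated once. A quick equivalence check (the bounds below are made-up values, not vLLM defaults):

```python
max_num_tokens = 512           # illustrative bound
uniform_decode_query_len = 16  # illustrative bound
cudagraph_batch_sizes = [1, 8, 16, 64, 512, 1024]

two_clause = [
    x for x in cudagraph_batch_sizes
    if x <= max_num_tokens and x >= uniform_decode_query_len
]
chained = [
    x for x in cudagraph_batch_sizes
    if max_num_tokens >= x >= uniform_decode_query_len
]
assert two_clause == chained == [16, 64, 512]
```
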