@@ -593,8 +593,8 @@ def _ubatch_split(
         if not self.parallel_config.enable_microbatching:
             return (None, 0, None)

+        # Check preconditions for microbatching
         total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
-        num_reqs = self.input_batch.num_reqs
         should_attempt_ubatching = \
             self.parallel_config.enable_microbatching and \
             total_num_scheduled_tokens >= \
@@ -610,11 +610,13 @@ def _ubatch_split(
         if not should_ubatch:
             return (None, 0, None)

-        # This doesn't actually pad the ubatch slices. It just initialize the
-        # split point to the correct value so that padding can be applied
+        # This doesn't actually pad the ubatch slices. It just initializes the
+        # split point to the padded value so that padding can be applied
         # to the second ubatch in pad_out_ubatch_slice after attention
         # metadata creation
-        assert num_pad_tokens < total_num_scheduled_tokens, f"num_pad_tokens {num_pad_tokens} original_num_tokens {total_num_scheduled_tokens}"
+        assert num_pad_tokens < total_num_scheduled_tokens, \
+            f"num_pad_tokens {num_pad_tokens} " \
+            f"original_num_tokens {total_num_scheduled_tokens}"
         total_num_tokens_per_ubatch = (total_num_scheduled_tokens +
                                        num_pad_tokens) // 2
         padded_first_ubatch_slice = slice(0, total_num_tokens_per_ubatch)
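
For illustration, a minimal standalone sketch of the split-point math above (example values; the second-slice name is hypothetical and not taken from vLLM):

# Not part of the diff: shows how the padded split point is chosen so that
# the padding lands entirely in the second ubatch.
total_num_scheduled_tokens = 9   # example: 9 tokens scheduled this step
num_pad_tokens = 1               # example: 1 pad token makes the split even

total_num_tokens_per_ubatch = (total_num_scheduled_tokens +
                               num_pad_tokens) // 2
padded_first_ubatch_slice = slice(0, total_num_tokens_per_ubatch)
# Hypothetical second slice; it stays shorter than the first until padding is
# applied in pad_out_ubatch_slice after attention metadata creation.
second_ubatch_slice = slice(total_num_tokens_per_ubatch,
                            total_num_scheduled_tokens)

assert total_num_tokens_per_ubatch == 5
assert padded_first_ubatch_slice == slice(0, 5)
assert second_ubatch_slice == slice(5, 9)
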
@@ -2945,14 +2947,17 @@ def _capture_cudagraphs(self, compilation_cases: list[int],
                     "decode" if uniform_decode else "mixed prefill-decode",
                     cudagraph_runtime_mode.name))
         enable_microbatching = self.parallel_config.enable_microbatching
-        # We skip EPLB here since we don't want to record dummy metrics
+        # DBO only supports running full cudagraphs with uniform
+        # decode lengths
         if enable_microbatching and uniform_decode:
             for num_tokens in compilation_cases:
                 # If the number of tokens is less than the microbatching
                 # threshold, don't generate a microbatched cudagraph
                 if (num_tokens
                         < self.parallel_config.microbatching_token_threshold):
                     continue
+
+                # Warmup
                 for _ in range(
                         self.compilation_config.cudagraph_num_of_warmups):
                     force_attention = (
@@ -2963,13 +2968,14 @@ def _capture_cudagraphs(self, compilation_cases: list[int],
                                     uniform_decode=True,
                                     allow_microbatching=True,
                                     skip_eplb=True)
-                # DBO Only supports running with Full cudagraphs with uniform
-                # decode lengths
+
+                # Graph Capture
                 self._dummy_run(num_tokens,
                                 cudagraph_runtime_mode=CUDAGraphMode.FULL,
                                 uniform_decode=True,
                                 allow_microbatching=True,
                                 skip_eplb=True)
+        # We skip EPLB here since we don't want to record dummy metrics
         for num_tokens in compilation_cases:
             for _ in range(self.compilation_config.cudagraph_num_of_warmups):
                 # Use CUDAGraphRuntimeStyle.NONE (default) for warmup.
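
For illustration, a schematic sketch of the capture gating added above (the function name and its return format are hypothetical, not vLLM APIs):

# Not part of the diff: only uniform-decode sizes at or above the
# microbatching token threshold get warmup passes followed by a FULL
# (microbatched) graph capture; smaller sizes are skipped here and fall
# through to the regular capture loop.
def plan_dbo_captures(compilation_cases, threshold, num_warmups):
    plan = []
    for num_tokens in compilation_cases:
        if num_tokens < threshold:
            continue  # too few tokens to split into two ubatches
        plan.extend([("warmup", num_tokens)] * num_warmups)
        plan.append(("capture_full", num_tokens))
    return plan

# Example: with threshold=64, only 64 and 128 are captured as DBO graphs.
assert plan_dbo_captures([8, 32, 64, 128], threshold=64, num_warmups=1) == [
    ("warmup", 64), ("capture_full", 64),
    ("warmup", 128), ("capture_full", 128),
]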