@@ -836,8 +836,9 @@ def _prepare_input_ids(self, total_num_scheduled_tokens: int,
         if self.input_batch.prev_sampled_token_ids is None:
             # Normal scheduling case
             self.input_ids.copy_to_gpu(total_num_scheduled_tokens)
-            self.inputs_embeds.copy_to_gpu(total_num_scheduled_tokens)
-            self.is_token_ids.copy_to_gpu(total_num_scheduled_tokens)
+            if self.enable_prompt_embeds:
+                self.inputs_embeds.copy_to_gpu(total_num_scheduled_tokens)
+                self.is_token_ids.copy_to_gpu(total_num_scheduled_tokens)
             return
 
         # Async scheduling case, where some decode requests from the previous
@@ -863,8 +864,9 @@ def _prepare_input_ids(self, total_num_scheduled_tokens: int,
         # If not all requests are decodes from the last iteration,
         # We need to copy the input_ids_cpu to the GPU first.
         self.input_ids.copy_to_gpu(total_num_scheduled_tokens)
-        self.inputs_embeds.copy_to_gpu(total_num_scheduled_tokens)
-        self.is_token_ids.copy_to_gpu(total_num_scheduled_tokens)
+        if self.enable_prompt_embeds:
+            self.inputs_embeds.copy_to_gpu(total_num_scheduled_tokens)
+            self.is_token_ids.copy_to_gpu(total_num_scheduled_tokens)
         if num_commmon_tokens == 0:
             # No requests in common with the previous iteration
             # So input_ids_cpu will have all the input ids.
@@ -878,7 +880,8 @@ def _prepare_input_ids(self, total_num_scheduled_tokens: int,
                 self.input_batch.prev_sampled_token_ids[:num_commmon_tokens,
                                                         0],
                 non_blocking=True)
-            self.is_token_ids.gpu[:num_commmon_tokens] = True
+            if self.enable_prompt_embeds:
+                self.is_token_ids.gpu[:num_commmon_tokens] = True
             return
         # Upload the index tensors asynchronously
         # so the scatter can be non-blocking.
@@ -978,12 +981,13 @@ def _prepare_inputs(
             0,
             token_indices_tensor,
             out=self.input_ids.cpu[:total_num_scheduled_tokens])
-        is_token_ids = self.input_batch.is_token_ids.flatten()
-        torch.index_select(
-            is_token_ids,
-            0,
-            token_indices_tensor,
-            out=self.is_token_ids.cpu[:total_num_scheduled_tokens])
+        if self.enable_prompt_embeds:
+            is_token_ids = self.input_batch.is_token_ids.flatten()
+            torch.index_select(
+                is_token_ids,
+                0,
+                token_indices_tensor,
+                out=self.is_token_ids.cpu[:total_num_scheduled_tokens])
 
         # Because we did not pre-allocate a massive prompt_embeds CPU tensor on
         # the InputBatch, we need to fill in the prompt embeds into the expected
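
Taken together, these hunks gate every inputs_embeds / is_token_ids transfer behind self.enable_prompt_embeds, so runs that never use prompt embeddings skip the extra host-to-device copies and the CPU-side index_select on each scheduling step. Below is a minimal runnable sketch of that gating pattern; StagingBuffer and Runner are simplified stand-ins invented for illustration, not the actual vLLM classes, and a real implementation would allocate pinned host memory so non_blocking copies can actually overlap with compute.

import torch


class StagingBuffer:
    """Paired CPU/GPU tensors with an explicit partial host-to-device copy.

    Hypothetical stand-in for the paired-buffer type used in the diff.
    """

    def __init__(self, *shape: int, dtype: torch.dtype, device: str):
        self.cpu = torch.zeros(*shape, dtype=dtype)
        self.gpu = torch.zeros(*shape, dtype=dtype, device=device)

    def copy_to_gpu(self, n: int) -> None:
        # Stage only the first n rows scheduled for this step.
        self.gpu[:n].copy_(self.cpu[:n], non_blocking=True)


class Runner:
    """Hypothetical model-runner fragment showing the gating in this commit."""

    def __init__(self, max_tokens: int, hidden_size: int, device: str,
                 enable_prompt_embeds: bool):
        self.enable_prompt_embeds = enable_prompt_embeds
        self.input_ids = StagingBuffer(max_tokens, dtype=torch.int64,
                                       device=device)
        # These two buffers are only consumed when prompt embeddings are
        # enabled, so their transfers can be skipped otherwise.
        self.inputs_embeds = StagingBuffer(max_tokens, hidden_size,
                                           dtype=torch.float32, device=device)
        self.is_token_ids = StagingBuffer(max_tokens, dtype=torch.bool,
                                          device=device)

    def _prepare_input_ids(self, total_num_scheduled_tokens: int) -> None:
        # Token ids are always needed on the device.
        self.input_ids.copy_to_gpu(total_num_scheduled_tokens)
        if self.enable_prompt_embeds:
            # Same gate as the diff: two host-to-device copies are avoided
            # on every scheduling step when the feature is off.
            self.inputs_embeds.copy_to_gpu(total_num_scheduled_tokens)
            self.is_token_ids.copy_to_gpu(total_num_scheduled_tokens)


# With prompt embeds disabled, only input_ids is transferred each step.
runner = Runner(max_tokens=1024, hidden_size=64, device="cpu",
                enable_prompt_embeds=False)
runner._prepare_input_ids(total_num_scheduled_tokens=128)

Keeping the input_ids copy unconditional while gating the prompt-embeds buffers means the common fast path pays only for the data it actually uses.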