@@ -461,7 +461,7 @@ def prefill_ray(
461461 existing_prefix : Optional [Prefix ] = None ,
462462 padded_tokens : PrefillInputs , # PrefillInputs[np.ndarray],
463463 true_length : int ,
464- ) -> None :
464+ ) -> tuple [ Prefix , engine_api . ResultTokens ] :
465465 """Do prefill in ray worker"""
466466 logits , updated_caches = self .prefill (
467467 params = params ,
@@ -476,7 +476,25 @@ def prefill_ray(
476476 prefix = Prefix (token , updated_caches , true_length )
477477 self .prefix_queue .put (prefix , block = False )
478478
479- return token
479+ token_out = jnp .reshape (token , (1 , 1 ))
480+ data = jnp .concatenate (
481+ [
482+ token_out , # First token
483+ jnp .ones_like (token_out ), # validity of first token
484+ jnp .zeros ((1 , 1 ), dtype = jnp .int32 ), # length = 0
485+ ],
486+ axis = - 1 ,
487+ )
488+ length = token_out .shape [1 ]
489+ result = engine_api .ResultTokens (
490+ data = data ,
491+ tokens_idx = (0 , length ),
492+ valid_idx = (length , 2 * length ),
493+ length_idx = (2 * length , 2 * length + 1 ),
494+ samples_per_slot = 1 ,
495+ )
496+
497+ return prefix , result
480498
481499 def _convert_to_np_caches (
482500 self , caches : List [Tuple [jax .Array , jax .Array ]]
@@ -495,7 +513,7 @@ def prefill_ray_disaggregation(
495513 existing_prefix : Optional [Prefix ] = None ,
496514 padded_tokens : PrefillInputs , # PrefillInputs[np.ndarray],
497515 true_length : int ,
498- ) -> Any :
516+ ) -> tuple [ NpPrefix , engine_api . ResultTokens ] :
499517 """Do prefill in ray worker"""
500518 logits , updated_caches = self .prefill (
501519 params = params ,
@@ -513,7 +531,25 @@ def prefill_ray_disaggregation(
513531 np_update_caches = self ._convert_to_np_caches (updated_caches )
514532 np_prefix = NpPrefix (token , np_update_caches , true_length )
515533
516- return np_prefix
534+ token_out = jnp .reshape (token , (1 , 1 ))
535+ data = jnp .concatenate (
536+ [
537+ token_out , # First token
538+ jnp .ones_like (token_out ), # validity of first token
539+ jnp .zeros ((1 , 1 ), dtype = jnp .int32 ), # length = 0
540+ ],
541+ axis = - 1 ,
542+ )
543+ length = token_out .shape [1 ]
544+ result = engine_api .ResultTokens (
545+ data = data ,
546+ tokens_idx = (0 , length ),
547+ valid_idx = (length , 2 * length ),
548+ length_idx = (2 * length , 2 * length + 1 ),
549+ samples_per_slot = 1 ,
550+ )
551+
552+ return np_prefix , result
517553
518554 def transfer (self , np_prefix : NpPrefix ) -> Any :
519555 """Transfer prefill result from object store to HBM"""
0 commit comments