@@ -1,7 +1,8 @@
+from collections import defaultdict
 from typing import Any, Iterable, Optional, Union
 
 import numpy as np
-import jax
 import ray
 from ray.util.accelerators import tpu
 
@@ -11,6 +11,7 @@
 Params = Any
 Prefix = Any
 DecodeState = Any
+NpPrefix = Any
 
 
 class PyTorchRayEngine(engine_api.Engine):
@@ -28,11 +29,15 @@ def __init__(
       tokenizer_path: str,
       context_length: int,
       batch_size: int,
+      is_disaggregated: bool = False,
+      pod_slice_name: str = None,
   ):
     self.engine_workers = engine_workers
     self.tokenizer_path = tokenizer_path
     self.context_length = context_length
     self.batch_size = batch_size
+    self.is_disaggregated = is_disaggregated
+    self.pod_slice_name = pod_slice_name
 
   # pylint: disable-next=all
   def load_params(self) -> Params:
@@ -64,17 +69,33 @@ def prefill(
   ) -> Prefix:
     all_outputs = []
     for worker in self.engine_workers:
-      output = worker.prefill_ray.remote(
+      prefill_func = (
+          worker.prefill_ray_disaggregation
+          if self.is_disaggregated
+          else worker.prefill_ray
+      )
+      output = prefill_func.remote(
           params=params,
           existing_prefix=existing_prefix,
           padded_tokens=padded_tokens,
           true_length=true_length,
       )
       all_outputs.append(output)
-    _ = ray.get(all_outputs)
+    results = ray.get(all_outputs)
     # The prefill function does not return any values;
     # the worker itself manages and maintains the prefill states.
-    return None
+    return results[0]
+
+  def transfer(self, np_prefix: NpPrefix) -> Any:
+    """Store prefill result into object store, then transfer to decode engine workers."""
+    all_outputs = []
+    np_prefix_ref = ray.put(np_prefix)
+    for worker in self.engine_workers:
+      output = worker.transfer.remote(np_prefix_ref)
+      all_outputs.append(output)
+    results = ray.get(all_outputs)
+
+    return results[0]
 
   def insert(
       self,
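
(Not part of the diff.) The new `transfer` method relies on a common Ray broadcast pattern: the NumPy prefix is placed in the object store once with `ray.put`, and the same `ObjectRef` is handed to every decode worker, so the payload is serialized a single time rather than once per `.remote()` call. A minimal self-contained sketch of that pattern, with an illustrative actor standing in for the real worker:

import numpy as np
import ray

@ray.remote
class Worker:
  """Illustrative stand-in for a decode engine worker."""

  def consume(self, payload):
    # Ray resolves the ObjectRef before the call, so payload is the array
    # itself; this plays the role of worker.transfer in the diff above.
    return payload.shape

ray.init(ignore_reinit_error=True)
workers = [Worker.remote() for _ in range(4)]
payload_ref = ray.put(np.zeros((1, 1024), dtype=np.float32))  # stored once
print(ray.get([w.consume.remote(payload_ref) for w in workers]))
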
@@ -126,7 +147,8 @@ def max_prefill_length(self) -> int:
 
   @property
   def colocated_cpus(self) -> Union[list[engine_api.CpuDevices], None]:
-    return jax.devices("cpu")[0]
+    # ray head doesn't load any parameters
+    return None
 
   def get_prefix_destination_sharding(self) -> Prefix:
     "No implementation"
@@ -153,16 +175,22 @@ def create_pytorch_ray_engine(
     quantize_kv=False,
     max_cache_length=1024,
     sharding_config=None,
-) -> PyTorchRayEngine:
+    is_disaggregated: bool = False,
+    num_hosts: int = 0,
+    decode_pod_slice_name: str = None,
+) -> Any:
 
+  # Return tuple as response: issues/107
   supported_models = ["llama-2", "llama-3", "gemma"]
   if model_name not in supported_models:
     raise NotImplementedError(
         f"Model name should be one of {','.join(supported_models)}"
     )
   ray.init(ignore_reinit_error=True)
   pod_name = tpu.get_current_pod_name()
-  num_hosts = tpu.get_current_pod_worker_count()
+  num_hosts = (
+      num_hosts if is_disaggregated else tpu.get_current_pod_worker_count()
+  )
   print(f"pod_name:{pod_name}, number of hosts: {num_hosts}")
   assert (
       pod_name is not None
@@ -192,10 +220,34 @@ def create_pytorch_ray_engine(
         sharding_config=sharding_config,
     )
     engine_workers.append(engine_worker)
-  engine_master = PyTorchRayEngine(
-      engine_workers=engine_workers,
+
+  if not is_disaggregated:
+    return PyTorchRayEngine(
+        engine_workers=engine_workers,
+        tokenizer_path=tokenizer_path,
+        context_length=context_length,
+        batch_size=batch_size,
+    )
+
+  workers_dict = defaultdict(list)
+  for worker in engine_workers:
+    pod_slice_name = ray.get(worker.pod_slice_name.remote())
+    workers_dict[pod_slice_name].append(worker)
+
+  prefill_engine = PyTorchRayEngine(
+      engine_workers=workers_dict[pod_name],
+      tokenizer_path=tokenizer_path,
+      context_length=context_length,
+      batch_size=batch_size,
+      is_disaggregated=is_disaggregated,
+      pod_slice_name=pod_name,
+  )
+  decode_engine = PyTorchRayEngine(
+      engine_workers=workers_dict[decode_pod_slice_name],
       tokenizer_path=tokenizer_path,
       context_length=context_length,
       batch_size=batch_size,
+      is_disaggregated=is_disaggregated,
+      pod_slice_name=decode_pod_slice_name,
   )
-  return engine_master
+  return (prefill_engine, decode_engine)
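
(Not part of the diff.) A minimal sketch of how the returned pair might be driven in disaggregated mode, assuming the JetStream-style engine surface (`load_params`, `init_decode_state`, `insert`, `generate`) implemented elsewhere in this file; the paths, slice name, host count, slot, and the pre-tokenized `tokens`/`true_length` inputs are placeholders rather than values from this PR:

prefill_engine, decode_engine = create_pytorch_ray_engine(
    model_name="llama-2",
    tokenizer_path="/path/to/tokenizer.model",  # placeholder
    ckpt_path="/path/to/checkpoint",            # placeholder
    is_disaggregated=True,
    num_hosts=4,                                # hosts per slice (assumed)
    decode_pod_slice_name="decode-slice",       # placeholder slice name
)

prefill_params = prefill_engine.load_params()
decode_params = decode_engine.load_params()

# Prefill runs on the prefill slice; with is_disaggregated=True each worker
# calls prefill_ray_disaggregation and the first worker's result is returned.
np_prefix = prefill_engine.prefill(
    params=prefill_params,
    padded_tokens=tokens,      # pre-tokenized, padded prompt (placeholder)
    true_length=true_length,   # placeholder
)

# Ship the prefill result to every decode worker through the Ray object store.
decode_engine.transfer(np_prefix)

# Decoding then proceeds on the decode slice as in the non-disaggregated path.
decode_state = decode_engine.init_decode_state()
decode_state = decode_engine.insert(np_prefix, decode_state, slot=0)
decode_state, result_tokens = decode_engine.generate(decode_params, decode_state)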