
Commit c360158

Ray Disaggregated Serving MVP (#106)
* Ray disaggregated MVP support
* add jax cpu
* add comments
* format
* format
* assign call prefill in one line
* refactor prefill in ray engine
* format
* clean up ray prefill
* remove duplicated flax installation
* add tuple as todo
1 parent 57e6fcf commit c360158

File tree: 4 files changed (+327, -10 lines)


install_everything.sh

Lines changed: 7 additions & 0 deletions
```diff
@@ -19,9 +19,16 @@ pip show libtpu-nightly && pip uninstall -y libtpu-nightly
 pip show tensorflow && pip uninstall -y tensorflow
 pip show ray && pip uninstall -y ray
 pip show flax && pip uninstall -y flax
+pip show keras && pip uninstall -y keras
+pip show tensorboard && pip uninstall -y tensorboard
+pip show tensorflow-text && pip uninstall -y tensorflow-text
+pip show torch_xla2 && pip uninstall -y torch_xla2
 
 pip install flax==0.8.3
 pip install jax[tpu]==0.4.28 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
+pip install tensorflow-text
+pip install tensorflow
+
 pip install ray[default]==2.22.0
 # torch cpu
 pip install torch==2.2.1+cpu --index-url https://download.pytorch.org/whl/cpu
```
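The version pins above (flax 0.8.3, jax 0.4.28, ray 2.22.0, torch 2.2.1+cpu) are what the rest of this commit assumes. As a quick sanity check, a minimal Python sketch (assuming the script above has already run) can confirm the installed versions match the pins:

```python
# Minimal sanity check for the pins in install_everything.sh.
# The expected versions below are taken from the script; adjust if the pins change.
import importlib.metadata as md

PINS = {"flax": "0.8.3", "jax": "0.4.28", "ray": "2.22.0", "torch": "2.2.1+cpu"}

for pkg, expected in PINS.items():
  installed = md.version(pkg)
  status = "OK" if installed == expected else "MISMATCH"
  print(f"{pkg}: installed={installed}, expected={expected} [{status}]")
```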

jetstream_pt/ray_engine.py

Lines changed: 62 additions & 10 deletions
```diff
@@ -1,7 +1,7 @@
+from collections import defaultdict
 from typing import Any, Iterable, Optional, Union
 
 import numpy as np
-import jax
 import ray
 from ray.util.accelerators import tpu
 
@@ -11,6 +11,7 @@
 Params = Any
 Prefix = Any
 DecodeState = Any
+NpPrefix = Any
 
 
 class PyTorchRayEngine(engine_api.Engine):
@@ -28,11 +29,15 @@ def __init__(
       tokenizer_path: str,
       context_length: int,
       batch_size: int,
+      is_disaggregated: bool = False,
+      pod_slice_name: str = None,
   ):
     self.engine_workers = engine_workers
     self.tokenizer_path = tokenizer_path
     self.context_length = context_length
     self.batch_size = batch_size
+    self.is_disaggregated = is_disaggregated
+    self.pod_slice_name = pod_slice_name
 
   # pylint: disable-next=all
   def load_params(self) -> Params:
@@ -64,17 +69,33 @@ def prefill(
   ) -> Prefix:
     all_outputs = []
     for worker in self.engine_workers:
-      output = worker.prefill_ray.remote(
+      prefill_func = (
+          worker.prefill_ray_disaggregation
+          if self.is_disaggregated
+          else worker.prefill_ray
+      )
+      output = prefill_func.remote(
           params=params,
           existing_prefix=existing_prefix,
           padded_tokens=padded_tokens,
           true_length=true_length,
      )
      all_outputs.append(output)
-    _ = ray.get(all_outputs)
+    results = ray.get(all_outputs)
     # The prefill function does not return any values;
     # the worker itself manages and maintains the prefill states.
-    return None
+    return results[0]
+
+  def transfer(self, np_prefix: NpPrefix) -> Any:
+    """Store prefill result into object store, then transfer to decode engine workers."""
+    all_outputs = []
+    np_prefix_ref = ray.put(np_prefix)
+    for worker in self.engine_workers:
+      output = worker.transfer.remote(np_prefix_ref)
+      all_outputs.append(output)
+    results = ray.get(all_outputs)
+
+    return results[0]
 
   def insert(
       self,
@@ -126,7 +147,8 @@ def max_prefill_length(self) -> int:
 
   @property
   def colocated_cpus(self) -> Union[list[engine_api.CpuDevices], None]:
-    return jax.devices("cpu")[0]
+    # ray head doesn't load any parameters
+    return None
 
   def get_prefix_destination_sharding(self) -> Prefix:
     "No implementation"
@@ -153,16 +175,22 @@ def create_pytorch_ray_engine(
     quantize_kv=False,
     max_cache_length=1024,
     sharding_config=None,
-) -> PyTorchRayEngine:
+    is_disaggregated: bool = False,
+    num_hosts: int = 0,
+    decode_pod_slice_name: str = None,
+) -> Any:
 
+  # Return tuple as reponse: issues/107
   supported_models = ["llama-2", "llama-3", "gemma"]
   if model_name not in supported_models:
     raise NotImplementedError(
         f"Model name should be one of{','.join(supported_models)}"
     )
   ray.init(ignore_reinit_error=True)
   pod_name = tpu.get_current_pod_name()
-  num_hosts = tpu.get_current_pod_worker_count()
+  num_hosts = (
+      num_hosts if is_disaggregated else tpu.get_current_pod_worker_count()
+  )
   print(f"pod_name:{pod_name}, number of host: {num_hosts}")
   assert (
       pod_name is not None
@@ -192,10 +220,34 @@ def create_pytorch_ray_engine(
         sharding_config=sharding_config,
     )
     engine_workers.append(engine_worker)
-  engine_master = PyTorchRayEngine(
-      engine_workers=engine_workers,
+
+  if not is_disaggregated:
+    return PyTorchRayEngine(
+        engine_workers=engine_workers,
+        tokenizer_path=tokenizer_path,
+        context_length=context_length,
+        batch_size=batch_size,
+    )
+
+  workers_dict = defaultdict(list)
+  for worker in engine_workers:
+    pod_slice_name = ray.get(worker.pod_slice_name.remote())
+    workers_dict[pod_slice_name].append(worker)
+
+  prefill_engine = PyTorchRayEngine(
+      engine_workers=workers_dict[pod_name],
+      tokenizer_path=tokenizer_path,
+      context_length=context_length,
+      batch_size=batch_size,
+      is_disaggregated=is_disaggregated,
+      pod_slice_name=pod_name,
+  )
+  decode_engine = PyTorchRayEngine(
+      engine_workers=workers_dict[decode_pod_slice_name],
       tokenizer_path=tokenizer_path,
       context_length=context_length,
       batch_size=batch_size,
+      is_disaggregated=is_disaggregated,
+      pod_slice_name=decode_pod_slice_name,
   )
-  return engine_master
+  return (prefill_engine, decode_engine)
```
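Taken together, these engine changes make `create_pytorch_ray_engine` return a `(prefill_engine, decode_engine)` pair when `is_disaggregated=True`, with workers grouped by TPU pod slice. The sketch below shows how that pair might be wired end to end; only the new flags, the tuple return, `prefill()`, and `transfer()` come from this diff, while the remaining arguments and the placeholder inputs are assumptions for illustration:

```python
# Hypothetical wiring of the disaggregated path added in this commit.
import numpy as np

from jetstream_pt.ray_engine import create_pytorch_ray_engine

prefill_engine, decode_engine = create_pytorch_ray_engine(
    model_name="llama-2",                       # must be one of supported_models
    tokenizer_path="/path/to/tokenizer.model",  # assumed argument
    context_length=1024,                        # assumed argument
    batch_size=1,                               # assumed argument
    is_disaggregated=True,
    num_hosts=4,                                # hosts per slice, assumed value
    decode_pod_slice_name="my-decode-slice",    # assumed slice name
)

params = prefill_engine.load_params()

# Placeholder prompt: a padded token buffer plus its true length.
padded_tokens = np.zeros((1024,), dtype=np.int32)
true_length = 16

# In disaggregated mode, prefill() dispatches prefill_ray_disaggregation on the
# prefill slice and returns the first worker's NpPrefix (token + numpy KV caches).
np_prefix = prefill_engine.prefill(
    params=params, padded_tokens=padded_tokens, true_length=true_length
)

# transfer() puts the NpPrefix into the Ray object store and hands it to every
# decode worker, which rebuilds jax caches and enqueues the Prefix for decoding.
decode_engine.transfer(np_prefix)
```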

jetstream_pt/ray_worker.py

Lines changed: 57 additions & 0 deletions
```diff
@@ -23,6 +23,7 @@
 import jax
 import numpy as np
 import ray
+from ray.util.accelerators import tpu
 import safetensors
 import torch
 import torch_xla2
@@ -57,6 +58,14 @@ class Prefix:
   seq_len: int  # true seqlen front pad
 
 
+@struct.dataclass
+# pylint: disable-next=all
+class NpPrefix:
+  token: jax.Array  # [1, seqlen]
+  caches: List[Tuple[jax.Array, jax.Array]]
+  seq_len: int  # true seqlen front pad
+
+
 @struct.dataclass
 # pylint: disable-next=all
 class DecodeState:
@@ -461,6 +470,50 @@ def prefill_ray(
 
     return token
 
+  def _convert_to_np_caches(
+      self, caches: List[Tuple[jax.Array, jax.Array]]
+  ) -> List[Tuple[np.ndarray, np.ndarray]]:
+    return [(np.asarray(tup[0]), np.asarray(tup[1])) for tup in caches]
+
+  def _convert_to_jax_caches(
+      self, np_caches: List[Tuple[np.ndarray, np.ndarray]]
+  ) -> List[Tuple[jax.Array, jax.Array]]:
+    return [(jnp.asarray(tup[0]), jnp.asarray(tup[1])) for tup in np_caches]
+
+  def prefill_ray_disaggregation(
+      self,
+      *,
+      params: Any,  # Weights
+      existing_prefix: Optional[Prefix] = None,
+      padded_tokens: PrefillInputs,  # PrefillInputs[np.ndarray],
+      true_length: int,
+  ) -> Any:
+    """Do prefill in ray worker"""
+    logits, updated_caches = self.prefill(
+        params=params,
+        existing_prefix=existing_prefix,
+        padded_tokens=padded_tokens,
+        true_length=true_length,
+    )
+    if len(logits.shape) == 3:  # b, seqlen, num words
+      logits = logits[0]
+
+    token = np.argmax(logits[true_length - 1])
+    updated_caches = multihost_utils.process_allgather(
+        updated_caches, tiled=True
+    )
+    np_update_caches = self._convert_to_np_caches(updated_caches)
+    np_prefix = NpPrefix(token, np_update_caches, true_length)
+
+    return np_prefix
+
+  def transfer(self, np_prefix: NpPrefix) -> Any:
+    """Transfer prefill result from object store to HBM"""
+    updated_caches = self._convert_to_jax_caches(np_prefix.caches)
+    prefix = Prefix(np_prefix.token, updated_caches, np_prefix.seq_len)
+    self.prefix_queue.put(prefix, block=False)
+    return True
+
   def shrink_prefix(
       self,
       prefix: Prefix,
@@ -884,3 +937,7 @@ def max_decode_length(self) -> int:
   def mesh(self):
     """return mesh"""
     return None
+
+  def pod_slice_name(self):
+    """pod slice name"""
+    return tpu.get_current_pod_name()
```
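The worker moves KV caches between slices by materializing them as numpy arrays (which Ray can serialize into the object store) and rebuilding jax arrays on the decode side. A standalone sketch of that round trip, mirroring `_convert_to_np_caches` / `_convert_to_jax_caches`; the cache shapes below are toy values, not the model's real KV layout:

```python
# Round-trip sketch of the cache conversion used by the disaggregated worker.
from typing import List, Tuple

import jax
import jax.numpy as jnp
import numpy as np


def to_np_caches(
    caches: List[Tuple[jax.Array, jax.Array]],
) -> List[Tuple[np.ndarray, np.ndarray]]:
  # Device arrays -> host numpy, so Ray can serialize them into the object store.
  return [(np.asarray(k), np.asarray(v)) for k, v in caches]


def to_jax_caches(
    np_caches: List[Tuple[np.ndarray, np.ndarray]],
) -> List[Tuple[jax.Array, jax.Array]]:
  # Host numpy -> device arrays on the decode slice (HBM).
  return [(jnp.asarray(k), jnp.asarray(v)) for k, v in np_caches]


# Round-trip a toy two-layer cache and check nothing was lost.
caches = [(jnp.ones((1, 8, 16, 4)), jnp.zeros((1, 8, 16, 4))) for _ in range(2)]
restored = to_jax_caches(to_np_caches(caches))
assert all(
    jnp.array_equal(k0, k1) and jnp.array_equal(v0, v1)
    for (k0, v0), (k1, v1) in zip(caches, restored)
)
```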
