Skip to content

Commit a58051d

Browse files
authored
Add ray multiple host support (#63)
* Add ray multiple host support * add dependencies * add dependencies * add assertion check on pod_name and num_hosts * Update ray engine and worker * update interactive * update ray worker * add comments * update comments
1 parent 2b1a527 commit a58051d

File tree

5 files changed

+1218
-2
lines changed

5 files changed

+1218
-2
lines changed

.pylintrc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
[MESSAGES CONTROL]
2-
disable=C0114,R0801,E1102,W0613
2+
disable=C0114,R0801,E1102,W0613,R1711

install_everything.sh

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ pip3 show libtpu-nightly && pip3 uninstall -y libtpu-nightly
2222
# jax with TPU support; -f (--find-links) lets pip look at the libtpu release page.
# The extras spec is quoted so the shell never treats [tpu] as a glob pattern.
pip3 install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
2323
# torch cpu
2424
pip3 install torch --index-url https://download.pytorch.org/whl/cpu
25-
pip3 install tensorflow flatbuffers absl-py flax sentencepiece seqio google-cloud-storage safetensors colorama coverage
25+
pip3 install tensorflow flatbuffers absl-py flax sentencepiece seqio google-cloud-storage
26+
pip3 install safetensors colorama coverage ray[default] humanize
2627

2728
mkdir -p deps
2829
pushd deps

jetstream_pt/ray_engine.py

Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,194 @@
1+
from typing import Any, Iterable, Optional, Union
2+
3+
import numpy as np
4+
import jax
5+
import ray
6+
from ray.util.accelerators import tpu
7+
8+
from jetstream.engine import engine_api, tokenizer_pb2
9+
from jetstream_pt.ray_worker import PyTorchRayWorker
10+
11+
# Opaque type aliases used in the annotations below; each is currently `Any`.
Params = Any
12+
Prefix = Any
13+
DecodeState = Any
14+
15+
16+
class PyTorchRayEngine(engine_api.Engine):
  """Ray PyTorch Engine Implementation for Multi-Host Inference Serving.

  Key Features:
  1. Manages all Ray workers.
  2. Initializes model parameters for each Ray worker.
  3. Routes incoming inference requests to Ray workers.
  4. Collects token responses from the Ray workers.

  Every engine call is fanned out to all workers: the remote method is
  started on each worker first, and only then does the engine block on
  `ray.get` for the whole batch of pending results.
  """

  def __init__(
      self,
      engine_workers: Iterable[PyTorchRayWorker],
      tokenizer_path: str,
      context_length: int,
      batch_size: int,
  ):
    """Store the worker handles and the serving configuration.

    Args:
      engine_workers: Ray worker handles that every call is dispatched to.
      tokenizer_path: Path reported by `get_tokenizer`.
      context_length: Value reported by `max_prefill_length`.
      batch_size: Value reported by `max_concurrent_decodes`.
    """
    # The workers are iterated once per engine call, so a one-shot iterable
    # (e.g. a generator) must be materialized or it would be empty after the
    # first call.
    self.engine_workers = list(engine_workers)
    self.tokenizer_path = tokenizer_path
    self.context_length = context_length
    self.batch_size = batch_size

  def _call_all_workers(self, method_name: str, **kwargs: Any) -> list:
    """Start `method_name` on every worker, then wait for all results.

    Args:
      method_name: Name of the remote method to invoke on each worker.
      **kwargs: Keyword arguments forwarded unchanged to each remote call.

    Returns:
      The per-worker results, in the same order as `self.engine_workers`.
    """
    # Launch everything before waiting so the workers run concurrently.
    pending = [
        getattr(worker, method_name).remote(**kwargs)
        for worker in self.engine_workers
    ]
    return ray.get(pending)

  # pylint: disable-next=all
  def load_params(self) -> Params:
    """Ask every worker to load its parameters; returns None."""
    self._call_all_workers("load_params_ray")
    return None

  # pylint: disable-next=all
  def init_decode_state(
      self,
  ) -> DecodeState:
    """Ask every worker to initialize its decode state; returns None."""
    self._call_all_workers("init_decode_state_ray")
    return None

  def prefill(
      self,
      *,
      params: Any,  # Weights
      existing_prefix: Optional[Prefix] = None,
      padded_tokens: np.ndarray,  # PrefillInputs[np.ndarray],
      true_length: int,
  ) -> Prefix:
    """Run prefill on every worker with the same arguments.

    Returns:
      None. The prefill function does not return any values; the worker
      itself manages and maintains the prefill states.
    """
    self._call_all_workers(
        "prefill_ray",
        params=params,
        existing_prefix=existing_prefix,
        padded_tokens=padded_tokens,
        true_length=true_length,
    )
    return None

  def insert(
      self,
      prefix: Prefix,
      decode_state: DecodeState,
      slot: int,
  ) -> DecodeState:
    """Run insert on every worker with the same arguments.

    Returns:
      None. The insert function does not return any values; the worker
      itself manages and maintains the DecodeState.
    """
    self._call_all_workers(
        "insert_ray", prefix=prefix, decode_state=decode_state, slot=slot
    )
    return None

  def generate(
      self, params: Any, decode_state: DecodeState
  ) -> tuple[None, engine_api.ResultTokens]:
    """Run one generate call on every worker and return worker 0's result.

    Returns:
      The `(state, result_tokens)` pair produced by the first worker.
    """
    per_worker_results = self._call_all_workers(
        "generate_ray", params=params, decode_state=decode_state
    )
    # All workers performed an all_gather operation. Since the results are
    # identical across all workers, the result from worker 0 is returned.
    state, result_tokens = per_worker_results[0]
    return state, result_tokens

  # pylint: disable-next=all
  def get_tokenizer(self) -> tokenizer_pb2.TokenizerParameters:
    """Return tokenizer parameters pointing at `self.tokenizer_path`."""
    # pylint: disable-next=all
    return tokenizer_pb2.TokenizerParameters(path=self.tokenizer_path)

  @property
  def max_concurrent_decodes(self) -> int:
    """Number of concurrent decodes; equal to the configured batch size."""
    return self.batch_size

  @property
  def samples_per_slot(self) -> int:
    """Samples per slot; always 1 for this engine."""
    return 1

  @property
  def max_prefill_length(self) -> int:
    """Maximum prefill length; equal to the configured context length."""
    return self.context_length

  @property
  def colocated_cpus(self) -> Union[list[engine_api.CpuDevices], None]:
    """Return the first JAX CPU device."""
    # NOTE(review): the annotation promises a list (or None) but a single
    # device is returned; kept as-is because callers may rely on it - confirm.
    return jax.devices("cpu")[0]

  def get_prefix_destination_sharding(self) -> Prefix:
    """No implementation."""
    return None

  @property
  def mesh(self):
    """No implementation."""
    return None
139+
140+
141+
# pylint: disable-next=all
142+
def create_pytorch_ray_engine(
    tokenizer_path: str,
    ckpt_path: Optional[str] = None,
    samples_per_slot: int = 1,
    bf16_enable: bool = False,
    param_size: str = "7b",
    context_length: int = 1024,
    batch_size: int = 1,
    max_decode_length: int = 4096,
    model_name="llama",
    quantize_weights=False,
    quantize_kv=False,
    max_cache_length=1024,
) -> PyTorchRayEngine:
  """Start one Ray worker per TPU pod host and wrap them in an engine.

  Initializes Ray (tolerating an already-initialized runtime), looks up the
  current TPU pod name and worker count, creates that many
  `PyTorchRayWorker` actors with identical arguments, and returns a
  `PyTorchRayEngine` that drives them.

  Args:
    tokenizer_path: Forwarded to every worker and to the engine.
    ckpt_path: Forwarded to every worker.
    samples_per_slot: Forwarded to every worker.
    bf16_enable: Forwarded to every worker.
    param_size: Forwarded to every worker.
    context_length: Forwarded to every worker and to the engine.
    batch_size: Forwarded to every worker and to the engine.
    max_decode_length: Forwarded to every worker.
    model_name: Forwarded to every worker.
    quantize_weights: Forwarded to every worker.
    quantize_kv: Forwarded to every worker.
    max_cache_length: Forwarded to every worker.

  Returns:
    The engine that fans calls out to the created workers.

  Raises:
    AssertionError: If no TPU pod name is available, or the pod worker
      count is missing or not positive.
  """
  ray.init(ignore_reinit_error=True)
  pod_name = tpu.get_current_pod_name()
  num_hosts = tpu.get_current_pod_worker_count()
  print(f"pod_name:{pod_name}, number of host: {num_hosts}")

  # Raised explicitly instead of via `assert` so the checks still run under
  # `python -O`; the exception type and messages are unchanged for callers.
  if pod_name is None:
    raise AssertionError(
        f"TPU pod name (current value:{pod_name}) can not be None"
    )
  # A missing worker count is reported with the same message instead of
  # failing on the `None > 0` comparison.
  if num_hosts is None or not num_hosts > 0:
    raise AssertionError(
        f"num_hosts (current value {num_hosts}) should be a positive number"
    )

  # Each actor requests 4 units of the "TPU" custom resource.
  # pylint: disable-next=all
  engine_worker_with_tpu_resource = PyTorchRayWorker.options(
      resources={"TPU": 4}
  )

  # All workers are constructed with exactly the same arguments.
  worker_kwargs = {
      "tokenizer_path": tokenizer_path,
      "ckpt_path": ckpt_path,
      "samples_per_slot": samples_per_slot,
      "bf16_enable": bf16_enable,
      "param_size": param_size,
      "context_length": context_length,
      "batch_size": batch_size,
      "max_decode_length": max_decode_length,
      "model_name": model_name,
      "quantize_weights": quantize_weights,
      "quantize_kv": quantize_kv,
      "max_cache_length": max_cache_length,
  }
  engine_workers = [
      engine_worker_with_tpu_resource.remote(**worker_kwargs)
      for _ in range(num_hosts)
  ]

  engine_master = PyTorchRayEngine(
      engine_workers=engine_workers,
      tokenizer_path=tokenizer_path,
      context_length=context_length,
      batch_size=batch_size,
  )
  return engine_master

0 commit comments

Comments
 (0)