Skip to content

Commit b234a57

Browse files
authored
Add run_server with ray for interleave serving (#109)
* Add run_server with ray * format
1 parent fac5c8e commit b234a57

File tree

1 file changed

+92
-0
lines changed

1 file changed

+92
-0
lines changed

run_server_with_ray.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Runs a pytorch server with ray."""
16+
import os
17+
import time
18+
from typing import Sequence
19+
from absl import app, flags
20+
21+
import jax
22+
from jetstream.core import server_lib
23+
from jetstream.core.config_lib import ServerConfig
24+
from jetstream_pt import ray_engine
25+
from jetstream_pt.config import FLAGS
26+
27+
flags.DEFINE_integer("port", 9000, "port to listen on")
28+
flags.DEFINE_integer("threads", 64, "number of worker threads in thread pool")
29+
flags.DEFINE_string(
30+
"config",
31+
"InterleavedCPUTestServer",
32+
"available servers",
33+
)
34+
flags.DEFINE_integer("prometheus_port", 0, "")
35+
flags.DEFINE_integer("tpu_chips", 16, "device tpu_chips")
36+
37+
38+
def create_engine():
39+
"""create a pytorch engine"""
40+
jax.config.update("jax_default_prng_impl", "unsafe_rbg")
41+
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
42+
43+
start = time.perf_counter()
44+
engine = ray_engine.create_pytorch_ray_engine(
45+
model_name=FLAGS.model_name,
46+
tokenizer_path=FLAGS.tokenizer_path,
47+
ckpt_path=FLAGS.checkpoint_path,
48+
bf16_enable=FLAGS.bf16_enable,
49+
param_size=FLAGS.size,
50+
context_length=FLAGS.context_length,
51+
batch_size=FLAGS.batch_size,
52+
quantize_weights=FLAGS.quantize_weights,
53+
quantize_kv=FLAGS.quantize_kv_cache,
54+
max_cache_length=FLAGS.max_cache_length,
55+
sharding_config=FLAGS.sharding_config,
56+
)
57+
58+
print("Initialize engine", time.perf_counter() - start)
59+
return engine
60+
61+
62+
# pylint: disable-next=all
63+
def main(argv: Sequence[str]):
64+
del argv
65+
os.environ["XLA_FLAGS"] = "--xla_dump_to=/tmp/xla_logs --xla_dump_hlo_as_text"
66+
devices = []
67+
for i in range(FLAGS.tpu_chips):
68+
devices.append(i)
69+
70+
print(f"devices: {devices}")
71+
72+
engine = create_engine()
73+
74+
server_config = ServerConfig(
75+
interleaved_slices=(f"tpu={len(devices)}",),
76+
interleaved_engine_create_fns=(lambda a: engine,),
77+
)
78+
print(f"server_config: {server_config}")
79+
80+
jetstream_server = server_lib.run(
81+
threads=FLAGS.threads,
82+
port=FLAGS.port,
83+
config=server_config,
84+
devices=devices,
85+
jax_padding=False, # Jax_padding must be set as False
86+
)
87+
print("Started jetstream_server....")
88+
jetstream_server.wait_for_termination()
89+
90+
91+
if __name__ == "__main__":
92+
app.run(main)

0 commit comments

Comments
 (0)