Skip to content

Commit cadb490

Browse files
authored
Move create_pytorch_engine to init. (#7)
Makes the UX a little bit cleaner
1 parent 21d1290 commit cadb490

File tree

2 files changed

+4
-3
lines changed

2 files changed

+4
-3
lines changed

jetstream_pt/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from jetstream_pt.engine import create_pytorch_engine

run_server.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,9 @@
66
from absl import flags
77

88
from jetstream.core import server_lib
9+
import jetstream_pt
910
from jetstream_pt import config
10-
from jetstream_pt import engine as je
11+
from jetstream.core.config_lib import ServerConfig
1112

1213

1314
_PORT = flags.DEFINE_integer('port', 9000, 'port to listen on')
@@ -61,15 +62,14 @@
6162
_QUANTIZE_KV_CACHE = flags.DEFINE_bool('quantize_kv_cache', False, 'kv_cache_quantize')
6263
_MAX_CACHE_LENGTH = flags.DEFINE_integer('max_cache_length', 1024, 'kv_cache_quantize')
6364

64-
from jetstream.core.config_lib import ServerConfig
6565

6666
def main(argv: Sequence[str]):
6767
del argv
6868
os.environ['XLA_FLAGS'] = '--xla_dump_to=/tmp/xla_logs --xla_dump_hlo_as_text'
6969
# No devices for local cpu test. A None for prefill and a None for generate.
7070
devices = server_lib.get_devices()
7171
print(f"devices: {devices}")
72-
engine = je.create_pytorch_engine(
72+
engine = jetstream_pt.create_pytorch_engine(
7373
devices=devices,
7474
tokenizer_path=_TOKENIZER_PATH.value,
7575
ckpt_path=_CKPT_PATH.value,

0 commit comments

Comments
 (0)