File tree Expand file tree Collapse file tree 2 files changed +4
-3
lines changed Expand file tree Collapse file tree 2 files changed +4
-3
lines changed Original file line number Diff line number Diff line change 1+ from jetstream_pt .engine import create_pytorch_engine
Original file line number Diff line number Diff line change 66from absl import flags
77
88from jetstream .core import server_lib
9+ import jetstream_pt
910from jetstream_pt import config
10- from jetstream_pt import engine as je
11+ from jetstream . core . config_lib import ServerConfig
1112
1213
1314_PORT = flags .DEFINE_integer ('port' , 9000 , 'port to listen on' )
6162_QUANTIZE_KV_CACHE = flags .DEFINE_bool ('quantize_kv_cache' , False , 'kv_cache_quantize' )
6263_MAX_CACHE_LENGTH = flags .DEFINE_integer ('max_cache_length' , 1024 , 'kv_cache_quantize' )
6364
64- from jetstream .core .config_lib import ServerConfig
6565
6666def main (argv : Sequence [str ]):
6767 del argv
6868 os .environ ['XLA_FLAGS' ] = '--xla_dump_to=/tmp/xla_logs --xla_dump_hlo_as_text'
6969 # No devices for local cpu test. A None for prefill and a None for generate.
7070 devices = server_lib .get_devices ()
7171 print (f"devices: { devices } " )
72- engine = je .create_pytorch_engine (
72+ engine = jetstream_pt .create_pytorch_engine (
7373 devices = devices ,
7474 tokenizer_path = _TOKENIZER_PATH .value ,
7575 ckpt_path = _CKPT_PATH .value ,
You can’t perform that action at this time.
0 commit comments