JetStream Engine implementation in PyTorch

### 1. Get the jetstream-pytorch code
```bash
git clone https://github.com/google/jetstream-pytorch.git
```

1.1 (optional) Create a virtual env using `venv` or `conda` and activate it.
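
A minimal sketch of the `venv` route (assuming a recent Python 3 with the built-in `venv` module):

```bash
python -m venv .venv
source .venv/bin/activate
```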

### 2. Run the installation script:

```bash
cd jetstream-pytorch
source install_everything.sh
```
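
To sanity-check the install, you can list the TPU devices JAX sees (assuming `install_everything.sh` installed a TPU-enabled JAX, which jetstream-pytorch builds on):

```bash
python -c "import jax; print(jax.devices())"
```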

# Local run

Set the tokenizer path:
```bash
export tokenizer_path=<path to the tokenizer.model file from meta-llama>
```
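
For instance — the path and value below are hypothetical (point them at your own meta-llama download), and `$quantize`, used by the commands that follow, is assumed to have been exported during the checkpoint-conversion step:

```bash
# Hypothetical locations/values; adjust to your setup.
export tokenizer_path=$HOME/llama/tokenizer.model
export quantize=True
```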

## Llama 7b
```bash
python run_interactive.py --size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir/model.safetensors --tokenizer_path=$tokenizer_path
```

## Llama 13b
```bash
python run_interactive.py --size=13b --batch_size=64 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir/model.safetensors --tokenizer_path=$tokenizer_path
```

# Run the server
NOTE: in `--platform=tpu=8`, the `8` specifies the number of TPU devices (which is 4 for a v4-8 and 8 for a v5light-8).
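
If you are unsure how many devices your TPU VM exposes, you can check with JAX (same assumption as above):

```bash
python -c "import jax; print(jax.device_count())"
```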

```bash
python run_server.py --param_size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir/model.safetensors --tokenizer_path=$tokenizer_path --platform=tpu=8
```
Now you can send gRPC requests to it.
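
For example, the JetStream repo ships a small test client; the module path and flag below are an assumption rather than a documented interface, so check `deps/JetStream` for the actual tool:

```bash
# Assumed helper from the JetStream repo; verify the exact module name and flags.
python -m jetstream.tools.requester --text "Why is the sky blue?"
```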

# Run benchmark
Go to the `deps/JetStream` folder (downloaded during `install_everything.sh`):
```bash
cd deps/JetStream
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
export dataset_path=ShareGPT_V3_unfiltered_cleaned_split.json
pip install -e .
python benchmarks/benchmark_serving.py --tokenizer $tokenizer_path --num-prompts 2000 --dataset-path $dataset_path --dataset sharegpt --save-request-outputs
```
Please look at `deps/JetStream/benchmarks/README.md` for more information.