# Jetstream-PyTorch
JetStream Engine implementation in PyTorch


# Install

### 1. Get the jetstream-pytorch code
```bash
git clone https://github.com/pytorch-tpu/jetstream-pytorch.git
```

1.1 (optional) Create a virtual env using `venv` or `conda` and activate it.

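A minimal sketch of this optional step, assuming Python 3 with the built-in `venv` module (the directory name `.venv` is just an example):

```bash
# Create and activate a virtual environment inside the checkout
cd jetstream-pytorch
python -m venv .venv
source .venv/bin/activate
```
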
### 2. Run installation script:

```bash
source install_everything.sh
```

NOTE: the above script exports `PYTHONPATH`, so it must be sourced (rather than run in a subshell) for the change to take effect in the current shell.

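To confirm the install took effect in your current shell, you can check that the variable is set:

```bash
# Should print a non-empty path list after sourcing install_everything.sh
echo $PYTHONPATH
```
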

# Get weights

### First get official llama weights from meta-llama

Follow the instructions here: https://github.com/meta-llama/llama#download

The download also includes a `tokenizer.model` file, which is the tokenizer we will use.

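The run commands below pass `--tokenizer_path=tokenizer.model` as a relative path, so one simple option (a sketch, not a requirement of the tooling) is to copy the tokenizer into the repo root:

```bash
# Adjust the source path to wherever the meta-llama download placed
# tokenizer.model on your machine (the path below is illustrative).
cp /path/to/llama-download/tokenizer.model .
```
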
### Run weight merger to convert (and optionally quantize) the weights
```bash
export input_ckpt_dir=<directory of the original llama weights>
export output_ckpt_dir=<output directory for the converted checkpoint>
export quantize=True  # whether to quantize the weights
python -m convert_checkpoints --input_checkpoint_dir=$input_ckpt_dir --output_checkpoint_dir=$output_ckpt_dir --quantize=$quantize
```

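After conversion, the output directory should contain the checkpoint file that the run commands below point at (the exact listing may vary):

```bash
# The commands below reference $output_ckpt_dir/model.safetensors
ls $output_ckpt_dir
```
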

# Local run

## Llama 7b
```
python benchmarks/run_offline.py --size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir/model.safetensors --tokenizer_path=tokenizer.model
```

## Llama 13b
```
python benchmarks/run_offline.py --size=13b --batch_size=96 --max_cache_length=1280 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir/model.safetensors --tokenizer_path=tokenizer.model
```
NOTE: for the 13b model we recommend `--max_cache_length=1280`; this effectively implements sliding-window attention.


# Run the server
NOTE: `--platform=tpu=8` must specify the number of TPU devices (4 for v4-8, 8 for v5light-8).

```bash
python run_server.py --param_size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir/model.safetensors --tokenizer_path=tokenizer.model --platform=tpu=8
```
Now you can send gRPC requests to it.

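For a quick smoke test, one option is the requester tool from the JetStream repo, which an earlier revision of these instructions used; the exact path is an assumption and may have moved:

```bash
# Hypothetical smoke test against the running server; the tool lives in the
# vendored JetStream checkout pulled in by install_everything.sh.
cd deps/JetStream
python jetstream/core/tools/requester.py
```
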

# Run benchmark
Go to the `deps/JetStream` folder (downloaded during `install_everything.sh`) and run the serving benchmark:
```bash
cd deps/JetStream
python benchmark_serving.py --tokenizer /home/hanq/jetstream-pytorch/tokenizer.model --num-prompts 2000 --dataset ~/data/ShareGPT_V3_unfiltered_cleaned_split.json --warmup-first=1 --save-request-outputs
```
Replace the `--tokenizer` path with the location of your own `tokenizer.model`, and `--dataset` with wherever you saved the ShareGPT file. The ShareGPT dataset can be downloaded with:

```bash
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
```
Please look at `deps/JetStream/benchmarks/README.md` for more information.