
Commit 04f1401

Update with the most current version
1 parent b7ebb23 commit 04f1401

33 files changed: +1610 -878 lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
@@ -1,3 +1,6 @@
+# source dependencies
+deps/
+
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[cod]

README.md

Lines changed: 52 additions & 28 deletions
@@ -1,51 +1,75 @@
-# petstream
+# Jetstream-PyTorch
 JetStream Engine implementation in PyTorch


-# Install torch_xla2
+# Install

+### 1. Get the jetstream-pytorch code
 ```bash
-git clone https://github.com/pytorch/xla.git
-cd xla/experimental/torch_xla2
-pip install -e .
+git clone https://github.com/pytorch-tpu/jetstream-pytorch.git
 ```

-# Merge weights
+1.1 (optional) Create a virtual env using `venv` or `conda` and activate it.
+
+### 2. Run the installation script:
+
+```bash
+source install_everything.sh
 ```
-export input_ckpt_dir = Original sharded pytorch checkpoints
-export output_ckpt_dir = The output director
-export output_safetensor = True/False, user can choose to store as SafeTensor
-format or not
-python petstream/pets/weight_merger.py --input_ckpt_dir={{input_ckpt_dir}} --output_ckpt_dir={{output_ckpt_dir}} --output_safetensor={{output_safetensor}}

-If user choose to load or store the checkpoints from Google Cloud Storage
-buckets, please make sure run `gcloud auth application-default login` beforehand
+NOTE: the above script exports `PYTHONPATH`, so sourcing it makes the change
+take effect in the current shell.
+
+
+# Get weights
+
+### First get the official llama weights from meta-llama
+
+Follow the instructions here: https://github.com/meta-llama/llama#download
+
+Downloading the weights also fetches a `tokenizer.model` file, which is the
+tokenizer we will use.
+
+### Run the weight merger to convert (and optionally quantize) the weights
+```bash
+export input_ckpt_dir=Original llama weights directory
+export output_ckpt_dir=The output directory
+export quantize=True  # whether to quantize
+python -m convert_checkpoints --input_checkpoint_dir=$input_ckpt_dir --output_checkpoint_dir=$output_ckpt_dir --quantize=$quantize
 ```


 # Local run
+
+## Llama 7b
 ```
-python -m petstream.jet_engine_python_run --bf16_enable=True --context_length=8 --batch_size=2
+python benchmarks/run_offline.py --size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir/model.safetensors --tokenizer_path=tokenizer.model
 ```

-# Bring up server
+## Llama 13b
 ```
-python -m run_server
-By default it runs on 'tpu=4', add --platform='cpu=1' if you are running on CPU
-By default it runs with tiny model, add --param_size='7b' to run 7b model
-
-Firing the request with:
-python jetstream/core/tools/requester.py
+python benchmarks/run_offline.py --size=13b --batch_size=96 --max_cache_length=1280 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir/model.safetensors --tokenizer_path=tokenizer.model
 ```
+NOTE: for the 13b model we recommend `--max_cache_length=1280`, which effectively implements sliding-window attention.
+
+
+# Run the server
+NOTE: `--platform=tpu=8` must specify the number of TPU devices (4 for v4-8 and 8 for v5light-8).

-# Profiling
+```bash
+python run_server.py --param_size=7b --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir/model.safetensors --tokenizer_path=tokenizer.model --platform=tpu=8
 ```
-export profiling_output = Some gcs bucket
-python -m petstream.jet_engine_python_run --bf16_enable=True --context_length=8 --batch_size=2 --profiling_output={{profiling_output}}
+Now you can fire gRPC requests at it.

-Switch to your Cloud top, run:
-export profiling_result = Some google generated folder in your gcs bucket
-petstream/gcs_to_cns.sh {{profiling_result}}
+# Run the benchmark
+Go to the deps/JetStream folder (downloaded during `install_everything.sh`):
+```bash
+cd deps/JetStream
+python benchmark_serving.py --tokenizer /home/hanq/jetstream-pytorch/tokenizer.model --num-prompts 2000 --dataset ~/data/ShareGPT_V3_unfiltered_cleaned_split.json --warmup-first=1 --save-request-outputs
+```
+The ShareGPT dataset can be downloaded with:

-The dump will always be in this directory: /cns/pi-d/home/{USER}/tensorboard/multislice/, load to Xprof/Offeline/Xplane
+```bash
+wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
 ```
+Please look at `deps/JetStream/benchmarks/README.md` for more information.
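
A note on step 1.1 of the new README (creating a virtual env): the commands are left implicit. Here is a minimal sketch using the standard-library `venv` module; the environment name `.venv` is only an example:

```bash
# Create and activate a virtual environment before sourcing install_everything.sh.
python -m venv .venv           # any directory name works; .venv is just an example
source .venv/bin/activate
```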
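
The `export` lines in the weight-conversion step are placeholders rather than runnable values. A concrete invocation might look like the following sketch; the paths are hypothetical and should point at your downloaded weights and a writable output directory:

```bash
# Hypothetical example paths; replace them with your own locations.
export input_ckpt_dir=$HOME/llama/llama-2-7b        # directory produced by the meta-llama download
export output_ckpt_dir=$HOME/llama/llama-2-7b-jet   # where the converted checkpoint will be written
export quantize=True                                # set to False to skip quantization
python -m convert_checkpoints \
  --input_checkpoint_dir=$input_ckpt_dir \
  --output_checkpoint_dir=$output_ckpt_dir \
  --quantize=$quantize
```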
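
The new README notes that you can fire gRPC requests at the running server but does not show a client. One low-effort smoke test, assuming the JetStream checkout under `deps/JetStream` still ships the simple requester tool that the old README pointed at (the path and any flags may differ in your checkout):

```bash
# Sketch: send a test request to the locally running server started by run_server.py.
# Assumes deps/JetStream still contains the requester tool referenced in the old README.
cd deps/JetStream
python jetstream/core/tools/requester.py
```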

benchmarks/decode_microbenchmark.py

Lines changed: 0 additions & 214 deletions
This file was deleted.
