# Jetstream-PyTorch
JetStream Engine implementation in PyTorch

# Outline

1. SSH to a Cloud TPU VM (using a v5e-8 TPU VM)
   a. Create a Cloud TPU VM if you haven’t already
2. Clone the jetstream-pytorch repo and install dependencies
3. Download and convert weights
4. Run the checkpoint converter (quantizer)
5. Local run
6. Run the server
7. Run benchmarks
8. Typical Errors

# SSH to Cloud TPU VM (using a v5e-8 TPU VM)

```bash
gcloud compute config-ssh
gcloud compute tpus tpu-vm ssh "your-tpu-vm" --project "your-project" --zone "your-project-zone"
```

## Create a Cloud TPU VM in a GCP project if you haven’t already
Follow steps 1-9 in this guide:
* https://cloud.google.com/tpu/docs/v5e-inference#prepare-a-project

# Clone repo and install dependencies

## Get the jetstream-pytorch code
```bash
git clone https://github.com/google/jetstream-pytorch.git
```

(optional) Create a virtual env using `venv` or `conda` and activate it.
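
A minimal sketch using the standard `venv` module (the environment name `venv` here is just an example):

```bash
# Create and activate a virtual environment
python -m venv venv
source venv/bin/activate
```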

## Run the installation script

```bash
cd jetstream-pytorch
source install_everything.sh
```
NOTE: the script above exports PYTHONPATH, so it must be sourced (not executed) for the change to take effect in your current shell.
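
A quick way to confirm the change took effect (assuming the script sets PYTHONPATH as noted above):

```bash
# Should print a non-empty path after sourcing install_everything.sh
echo $PYTHONPATH
```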

# Download and convert weights

## Get official llama weights from meta-llama

Follow the instructions here: https://github.com/meta-llama/llama#download

The download also includes a `tokenizer.model` file, which is the tokenizer we will use.
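
Later steps (such as the benchmark command below) reference this file through a `$tokenizer_path` variable; a sketch, with a hypothetical path:

```bash
# Hypothetical location; point this at wherever your download placed tokenizer.model
export tokenizer_path=llama/tokenizer.model
```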

## Run the safetensors weight converter

```bash
export input_ckpt_dir="directory containing the original llama weights"
export output_ckpt_dir="the output directory"
```

…

```bash
export dataset_path=ShareGPT_V3_unfiltered_cleaned_split.json
python benchmarks/benchmark_serving.py --tokenizer $tokenizer_path --num-prompts 2000 --dataset-path $dataset_path --dataset sharegpt --save-request-outputs
```
Please look at `deps/JetStream/benchmarks/README.md` for more information.
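
If you do not already have the ShareGPT dataset locally, it can be downloaded first; the Hugging Face URL below is an assumption about where this file is hosted:

```bash
# Assumed public mirror of the ShareGPT dataset referenced above
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
```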


# Typical Errors

## Unexpected keyword argument 'device'

Fix:
* Uninstall the jax and jaxlib dependencies
* Reinstall using `source install_everything.sh` (see the sketch after this list)
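
A minimal sketch of that fix (assuming the packages were installed with pip):

```bash
# Remove the conflicting JAX packages, then reinstall via the repo script
pip uninstall -y jax jaxlib
source install_everything.sh
```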

## Out of memory

Fix:
* Use a smaller batch size
* Use quantization