3 files changed: +23 -2 lines changed
@@ -109,7 +109,17 @@ NOTE: the `--platform=tpu=8` flag needs to specify the number of TPU devices (which is 4 f
``` bash
python run_server.py --param_size=7b --model_name=$model_name --batch_size=128 --max_cache_length=2048 --quantize_weights=$quantize --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --platform=tpu=8 --model=$model_name
```
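If you are unsure how many TPU devices the host exposes, a quick check from Python is shown below (a minimal sketch; it assumes JAX is available in the environment):

```python
# Print the number of accelerator devices visible to JAX; use this value
# in --platform=tpu=<count> (e.g. 8 on a v5e-8 host).
import jax

print(jax.device_count())
```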
- Now you can fire gRPC to it
+
+ Now you can fire gRPC requests to it.
+
+ Optional flags:
+ * `--shard_on_batch=1` Shards the model on the batch dimension, i.e. the server
+ runs in data-parallel mode instead of model-parallel mode, and the sharding
+ config is ignored. This is recommended for the Gemma 2B model, because Gemma 2B
+ is small enough to fit on a single TPU chip (see the sketch after this list).
+
+ * `--sharding_config=<path>` Uses an alternative sharding config instead of the
+ ones in the `default_shardings` directory.

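To make the distinction between the two modes concrete, here is a minimal JAX sketch (illustrative only: the mesh axis name and toy tensor shapes are assumptions, not the server's actual sharding code):

```python
# Illustrative sketch of batch sharding (data parallel) vs. weight sharding
# (model parallel). The axis name "x" and the toy shapes are made up.
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=("x",))

activations = jax.numpy.zeros((128, 2048))  # (batch, hidden)
weights = jax.numpy.zeros((2048, 2048))     # (hidden, hidden)

# --shard_on_batch=1: split the batch dimension across chips; every chip keeps
# a full copy of the weights.
data_parallel = jax.device_put(activations, NamedSharding(mesh, P("x", None)))

# Default mode: split a weight dimension across chips, as described by the
# per-tensor entries in the sharding config.
model_parallel = jax.device_put(weights, NamedSharding(mesh, P(None, "x")))
```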
# Run benchmark
go to the deps/JetStream folder (downloaded during `install_everything.sh`)
@@ -16,6 +16,17 @@ Date | Device | dtype | batch size | cache length | max input length | max output
2024-05-10 | TPU v5e-8 | bfloat16 | 96 | 2048 | 1024 | 1024 | 3236
2024-05-10 | TPU v5e-8 | int8 | 128 | 2048 | 1024 | 1024 | 4695

+ ## Gemma - 2B
+
+ Date | Device | dtype | batch size | cache length | max input length | max output length | throughput (token/s)
+ ---- | ------ | ----- | ---------- | ------------ | ---------------- | ----------------- | ---------------------
+ 2024-05-14 | TPU v5e-8 | bfloat16 | 512 | 2048 | 1024 | 1024 | 8700
+ 2024-05-14 | TPU v5e-8 | int8 | 1024 | 2048 | 1024 | 1024 | 8746
+
+ **NOTE:** The Gemma 2B runs use the `--shard_on_batch` flag, so they are data
+ parallel instead of model parallel.
+
+
## Llama 2 - 7B

Date | Device | dtype | batch size | cache length | max input length | max output length | throughput (token/s)
@@ -176,7 +176,7 @@ def make_caches_generate(self):
  def sharding_by_name(self, name):
    """Create sharding specified in the config."""
    if self.shard_on_batch:
-     return self.shading_by_axis(0)  # batch dimension
+     return self.sharding_by_axis(0)  # batch dimension

    if name in self._sharding_config:
      return self.sharding_by_axis(self._sharding_config[name])
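For reference, a `sharding_by_axis`-style helper can be sketched roughly as below with `jax.sharding`; this is an assumed illustration, not the repository's actual implementation (the mesh axis name and the replicate-on-`-1` convention are guesses):

```python
# Hypothetical sketch of a sharding_by_axis-style helper: partition a tensor
# along `axis` across all devices on a one-dimensional mesh, or replicate it
# when axis is -1. Assumed illustration only.
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def sharding_by_axis(axis: int, num_dims: int = 2) -> NamedSharding:
  mesh = Mesh(np.array(jax.devices()), axis_names=("x",))
  if axis == -1:
    # Replicate: no dimension is split across devices.
    return NamedSharding(mesh, P())
  # Split only the requested dimension across the "x" mesh axis.
  spec = [None] * num_dims
  spec[axis] = "x"
  return NamedSharding(mesh, P(*spec))
```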