Commit 96132b4

[None] [doc] Add Mixed Precision Context and Generation section to Disagg (#8769)

Signed-off-by: Timothy Gao <35588167+timothygao8710@users.noreply.github.com>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>

examples/disaggregated/README.md

Lines changed: 57 additions & 0 deletions
@@ -203,6 +203,63 @@ srun -A <account> -p <partition> -t <time> \

Additionally, we offer a fully executable script—please refer to [Disaggregated SLURM Scripts](./slurm/simple_example/).

## Mixed Precision Context and Generation

In disaggregated serving, the context workers and generation workers have different performance characteristics: context workers are compute-bound while generation workers are memory-bound. Therefore, it may be beneficial to run context workers and generation workers in different precisions.
### Prerequisites

To enable mixed precision serving, you will need:
1. A quantized checkpoint created with [TensorRT Model Optimizer](https://github.com/NVIDIA/TensorRT-Model-Optimizer) (a minimal setup sketch follows this list)
2. The original checkpoint, which may itself be quantized or unquantized
3. Both checkpoints to use the same KV cache dtype, to ensure compatibility during KV cache transfer
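
The quantization command in the next section invokes `hf_ptq.py` through `$MODELOPT_ROOT`. A minimal setup sketch is shown below; it assumes a local checkout of the Model Optimizer repository with its `llm_ptq` example dependencies installed, and any other installation that provides `examples/llm_ptq/hf_ptq.py` works just as well:

```bash
# Assumed setup: point MODELOPT_ROOT at a local TensorRT Model Optimizer checkout
# so that $MODELOPT_ROOT/examples/llm_ptq/hf_ptq.py is available.
git clone https://github.com/NVIDIA/TensorRT-Model-Optimizer.git
export MODELOPT_ROOT=$PWD/TensorRT-Model-Optimizer
```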
### Example (BF16 Context, FP8 Generation)

An FP8-quantized checkpoint with an unquantized (BF16) KV cache can be created by passing `--kv_cache_qformat none`:
```bash
# Quantize the weights to FP8 while leaving the KV cache unquantized (BF16)
python $MODELOPT_ROOT/examples/llm_ptq/hf_ptq.py \
    --pyt_ckpt_path=meta-llama/Llama-3.1-8B-Instruct \
    --export_path=./weights/Llama-3.1-8B-Instruct-FP8-KV-BF16 \
    --sparsity_fmt=dense \
    --qformat=fp8 \
    --calib_size=512 \
    --batch_size=8 \
    --inference_tensor_parallel=1 \
    --inference_pipeline_parallel=1 \
    --kv_cache_qformat none \
    --export_fmt=hf
```
Verify both checkpoints have the same KV cache dtype by checking `hf_quant_config.json`.
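
As a quick sanity check, you can inspect the quantization config of the FP8 export. The exact JSON layout depends on the Model Optimizer version, so treat the key name below as an assumption:

```bash
# Illustrative check: ModelOpt HF exports typically record the KV cache format
# under a "kv_cache_quant_algo" entry in hf_quant_config.json. With
# --kv_cache_qformat none it should indicate no KV cache quantization, matching
# the BF16 KV cache of the original context checkpoint.
grep -i "kv_cache" ./weights/Llama-3.1-8B-Instruct-FP8-KV-BF16/hf_quant_config.json
```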
```bash
# Start context servers with original BF16 checkpoint
CUDA_VISIBLE_DEVICES=0 trtllm-serve meta-llama/Llama-3.1-8B-Instruct \
    --host localhost --port 8001 \
    --server_role CONTEXT \
    --extra_llm_api_options ./ctx_extra-llm-api-config.yaml \
    --metadata_server_config_file ./metadata_config.yaml &> log_ctx_0 &

CUDA_VISIBLE_DEVICES=1 trtllm-serve meta-llama/Llama-3.1-8B-Instruct \
    --host localhost --port 8002 \
    --server_role CONTEXT \
    --extra_llm_api_options ./ctx_extra-llm-api-config.yaml \
    --metadata_server_config_file ./metadata_config.yaml &> log_ctx_1 &

# Start generation server with FP8 quantized checkpoint
CUDA_VISIBLE_DEVICES=2 trtllm-serve ./weights/Llama-3.1-8B-Instruct-FP8-KV-BF16 \
    --host localhost --port 8003 \
    --server_role GENERATION \
    --extra_llm_api_options ./gen_extra-llm-api-config.yaml \
    --metadata_server_config_file ./metadata_config.yaml &> log_gen_0 &

# Start disaggregated server
trtllm-serve disaggregated -c disagg_config.yaml -m ./metadata_config.yaml
```
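
The commands above assume a `disagg_config.yaml` that maps the context and generation endpoints. A minimal sketch matching the ports used here is shown below; when a metadata server is used, the worker lists may instead be discovered dynamically, so treat the static URL form as an assumption and adapt it to the configuration format described earlier in this README:

```bash
# Illustrative disagg_config.yaml matching the ports above (adjust to your deployment;
# with a metadata server, dynamic discovery may replace the static URL lists).
cat > disagg_config.yaml <<'EOF'
hostname: localhost
port: 8000
context_servers:
  num_instances: 2
  urls:
    - "localhost:8001"
    - "localhost:8002"
generation_servers:
  num_instances: 1
  urls:
    - "localhost:8003"
EOF
```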
You can also run FP8 for context and BF16 for generation, as long as the KV cache dtype is consistent across all workers.

## Dynamic scaling