Optimizing Token Generation in llama.cpp's CUDA Backend #17621
am17an started this conversation in Show and tell
LLM inference is divided into 2 phases: prompt-processing (PP) and token-generation (TG). Prompt processing is when entire prompt is fed through the model, whereas token-generation is when the model starts outputting one token at a time. These are different workloads, for decoder-only auto-regressive LLMs PP is compute-bound and TG is memory-bound.
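To get a rough sense of why TG is memory-bound: for every generated token the GPU has to stream essentially all of the (active) model weights from VRAM, so, ignoring compute and KV-cache traffic, the per-token time is bounded below by the weight bytes divided by memory bandwidth. In the sketch below (my symbols, not from the original derivation), $B$ is the bytes of weights read per token and $\beta$ is the memory bandwidth:

$$t_{\text{token}} \;\gtrsim\; \frac{B}{\beta} \quad\Longrightarrow\quad \text{TG rate} \;\lesssim\; \frac{\beta}{B}\ \text{tokens/s}$$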
This post covers optimizations for TG using kernel fusion and concurrent streams in the CUDA backend for llama.cpp. Other backends also do fusion but differently, and the rest of this post assumes the CUDA backend. This post is lightly technical with links to PRs for people who want to go into further detail.
I ran tests on a 4090 and a 5090, with the baseline being these optimizations disabled:

```
GGML_CUDA_GRAPH_OPT=1 ./llama-bench -m qwen3_30b_qk_m.gguf,gpt_oss_20b-mxfp4.gguf -fa 1 -p 0
```

vs. the baseline

```
GGML_CUDA_DISABLE_FUSION=1 GGML_CUDA_GRAPH_OPT=0 ./llama-bench -m qwen3_30b_qk_m.gguf,gpt_oss_20b-mxfp4.gguf -fa 1 -p 0
```

According to my rough napkin math¹, the theoretical speed-of-light for `gpt-oss-20b` on a 5090 is ~520 t/s, and we are able to reach ~419 t/s (~80%), although the speedups become more modest as the context size grows.

Most of these optimizations are enabled by default in llama.cpp; concurrent streams are enabled via the environment flag `GGML_CUDA_GRAPH_OPT=1` (note that this particular optimization only works for a single GPU at the moment, and that it has nothing to do with CUDA Graphs).

Kernel Fusion
Fusing kernels reduces memory traffic and kernel launch overhead. Since TG is memory-bound, kernel fusion helps there the most. In my experience, fusion is most useful when memory bandwidth is high; conversely, the results are not as impressive when bandwidth is low.
Note that some fusions increase register pressure and generally lead to more complex kernels, so fusion is not always worth doing.
llama.cpp's backend is based on ggml, a low-level tensor library. ggml offers a lot of flexibility to backends via its cgraph (compute graph) interface: after upfront tensor allocation and initial graph ordering, it passes control to the individual backends to compute the graph. The graph itself is composed of `ggml_tensor`s, which contain all the information required for fusion.

The backends are then responsible for fusion. The CUDA backend inspects the order of operations and decides whether to fuse a run of individual kernels into one. This pattern can be described in pseudo-code:
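Something along these lines (a simplified sketch, not the actual ggml-cuda code; the `ggml_cgraph`/`ggml_tensor` fields and the `GGML_OP_*` enums are real ggml types, but the `launch_*` helpers are made-up placeholders):

```cpp
// Sketch of the fusion dispatch loop: walk the compute graph in order and,
// when a known pattern of adjacent nodes is recognized, launch one fused
// kernel instead of several individual ones.
for (int i = 0; i < cgraph->n_nodes; ++i) {
    ggml_tensor * node = cgraph->nodes[i];
    ggml_tensor * next = i + 1 < cgraph->n_nodes ? cgraph->nodes[i + 1] : nullptr;

    // example pattern: RMS_NORM immediately followed by MUL of its output
    if (next && node->op == GGML_OP_RMS_NORM && next->op == GGML_OP_MUL
             && next->src[0] == node) {
        launch_fused_rms_norm_mul(node, next); // hypothetical helper
        i += 1;                                // skip the node we just consumed
        continue;
    }

    launch_single_op(node);                    // hypothetical per-op fallback
}
```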
In ggml-cuda, the relevant portions start from here
Looking at the order of nodes, one can identify opportunities for fusion. A pattern that appears in a lot of models is usually a good candidate. The following are the most effective fusions the CUDA backend does at the moment:
GEMV fusion
Matrix-vector multiplication (GEMV) is used instead of matrix-matrix multiplication during TG, as we operate on 1 token at a time. We identify operations where a `MUL_MAT` is followed by an `ADD`, or, more interestingly, where there is a gated activation, i.e. the operation $o = \sigma(W_{gate} X)\odot W_{up}X$, so we can re-use the activation $X$ and multiply it by both matrices. More details here
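To make the idea concrete, here is a minimal standalone CUDA sketch of a fused gated GEMV. This is not llama.cpp's kernel (which handles quantized weights and is far more optimized); the choice of SiLU for $\sigma$ and all names are illustrative:

```cpp
// Fused gated GEMV sketch: o = silu(W_gate * x) * (W_up * x).
// One warp (blockDim.x == 32) per output row; x[col] is loaded once
// and used for both dot products, instead of two GEMVs plus a MUL.
__global__ void fused_gated_gemv(const float * __restrict__ W_gate,
                                 const float * __restrict__ W_up,
                                 const float * __restrict__ x,
                                 float * __restrict__ o,
                                 int n_cols) {
    const int row = blockIdx.x;
    float sum_gate = 0.0f, sum_up = 0.0f;

    for (int col = threadIdx.x; col < n_cols; col += blockDim.x) {
        const float xv = x[col];
        sum_gate += W_gate[(size_t) row * n_cols + col] * xv;
        sum_up   += W_up  [(size_t) row * n_cols + col] * xv;
    }

    // warp reduction of both partial sums
    for (int offset = warpSize / 2; offset > 0; offset >>= 1) {
        sum_gate += __shfl_down_sync(0xffffffff, sum_gate, offset);
        sum_up   += __shfl_down_sync(0xffffffff, sum_up,   offset);
    }

    if (threadIdx.x == 0) {
        const float g = sum_gate / (1.0f + expf(-sum_gate)); // SiLU
        o[row] = g * sum_up;
    }
}
```

A launch would look like `fused_gated_gemv<<<n_rows, 32>>>(d_gate, d_up, d_x, d_o, n_cols)` with device pointers.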
TopK-MoE
This is the common scoring algorithm in MoE models for selecting the expert weights for each token. It usually involves a softmax over the per-expert logits, followed by a top-k selection of the experts used in the sparse MoE matrix multiplication (called `MUL_MAT_ID` in ggml). More details here
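For reference, here is a plain C++ version of the unfused scoring step (softmax, then top-k); the fused kernel does the same work in a single launch per token. The function name and the use of `std::partial_sort` are illustrative, not taken from llama.cpp:

```cpp
#include <algorithm>
#include <cmath>
#include <numeric>
#include <vector>

// Unfused reference: softmax over per-expert logits, then pick the top-k experts.
std::vector<int> topk_moe_scores(const std::vector<float> & logits, int k,
                                 std::vector<float> & weights_out) {
    // numerically stable softmax
    const float max_logit = *std::max_element(logits.begin(), logits.end());
    std::vector<float> probs(logits.size());
    float sum = 0.0f;
    for (size_t i = 0; i < logits.size(); ++i) {
        probs[i] = std::exp(logits[i] - max_logit);
        sum += probs[i];
    }
    for (float & p : probs) {
        p /= sum;
    }

    // indices of the k largest probabilities
    std::vector<int> idx(logits.size());
    std::iota(idx.begin(), idx.end(), 0);
    std::partial_sort(idx.begin(), idx.begin() + k, idx.end(),
                      [&](int a, int b) { return probs[a] > probs[b]; });
    idx.resize(k);

    weights_out.clear();
    for (int i : idx) {
        weights_out.push_back(probs[i]); // routing weights fed to MUL_MAT_ID
    }
    return idx;
}
```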
Fused RMS Norm
This is a relatively simple one: usually an `RMS_NORM` is followed by a `MUL` and optionally an `ADD`, and we just do everything in the same kernel. More details here.
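A minimal sketch of the fused pattern (again, not the real kernel, which is more general): one warp per row computes the RMS, then normalizes, applies the `MUL` weight, and optionally adds a bias, without ever writing the intermediate normalized tensor to VRAM:

```cpp
// Fused RMS_NORM + MUL (+ optional ADD) sketch; one warp per row (blockDim.x == 32).
__global__ void fused_rms_norm_mul(const float * __restrict__ x,
                                   const float * __restrict__ weight,
                                   const float * __restrict__ bias, // may be nullptr
                                   float * __restrict__ out,
                                   int n_cols, float eps) {
    const float * row_in  = x   + (size_t) blockIdx.x * n_cols;
    float       * row_out = out + (size_t) blockIdx.x * n_cols;

    // 1. sum of squares for this row
    float ss = 0.0f;
    for (int col = threadIdx.x; col < n_cols; col += blockDim.x) {
        ss += row_in[col] * row_in[col];
    }
    for (int offset = warpSize / 2; offset > 0; offset >>= 1) {
        ss += __shfl_down_sync(0xffffffff, ss, offset);
    }
    const float inv_rms = rsqrtf(__shfl_sync(0xffffffff, ss, 0) / n_cols + eps);

    // 2. normalize, scale by the MUL weight, optionally add the bias
    for (int col = threadIdx.x; col < n_cols; col += blockDim.x) {
        const float v = row_in[col] * inv_rms * weight[col];
        row_out[col] = bias ? v + bias[col] : v;
    }
}
```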
Other fusions include fusing adds and fusing the softcap operation.
Concurrent CUDA Streams
Since the query (Q), key (K), and value (V) projections in the attention block are independent of each other prior to calculating the attention scores, we can run them concurrently using CUDA streams. This is best explained with an Nsight diagram.

The ggml graph allocator optimizes for memory footprint by reusing 'expired' buffers immediately. In a multi-stream context, 'expired' is ambiguous, leading to race conditions where stream B overwrites stream A's input. This is currently "solved" by interleaving the graph nodes to extend their lifetimes; a better long-term fix in ggml to allow for such scenarios is planned. More details here
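For illustration, the fork/join pattern from the diagram above looks roughly like this in plain CUDA. This is a standalone sketch, not llama.cpp's implementation; the `launch_*` functions are hypothetical stand-ins for the projection and attention kernel launches, and error checking is omitted:

```cpp
#include <cuda_runtime.h>

// Hypothetical stand-ins for the actual kernel launches.
void launch_q_projection(cudaStream_t stream);
void launch_k_projection(cudaStream_t stream);
void launch_v_projection(cudaStream_t stream);
void launch_attention(cudaStream_t stream);

// Run the independent Q/K/V projections on separate streams, then re-join
// on the main stream before computing attention.
void qkv_concurrent(cudaStream_t main_stream) {
    static cudaStream_t side[2];
    static cudaEvent_t  ready, done[2];
    static bool init = false;
    if (!init) {
        for (int i = 0; i < 2; ++i) {
            cudaStreamCreateWithFlags(&side[i], cudaStreamNonBlocking);
            cudaEventCreateWithFlags(&done[i], cudaEventDisableTiming);
        }
        cudaEventCreateWithFlags(&ready, cudaEventDisableTiming);
        init = true;
    }

    // fork: side streams wait until the shared input is ready on main_stream
    cudaEventRecord(ready, main_stream);
    cudaStreamWaitEvent(side[0], ready, 0);
    cudaStreamWaitEvent(side[1], ready, 0);

    launch_q_projection(main_stream);
    launch_k_projection(side[0]);
    launch_v_projection(side[1]);

    // join: attention on main_stream waits for K and V to finish
    for (int i = 0; i < 2; ++i) {
        cudaEventRecord(done[i], side[i]);
        cudaStreamWaitEvent(main_stream, done[i], 0);
    }
    launch_attention(main_stream);
}
```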
Conclusion
Although none of these PRs increases TG throughput by more than 10% on its own, taken all together they give a nice speedup.
Acknowledgements
Johannes, Jeff, Georgi, Oliver, Sigbjørn
Footnotes
1. speed-of-light (SoL) napkin math