Skip to content

Commit 2f8bd6f

Browse files
authored
[#9150][feat] AutoDeploy Nemotron-Flash support (#9504)
Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
1 parent c2562fc commit 2f8bd6f

File tree

21 files changed

+1856
-12
lines changed

21 files changed

+1856
-12
lines changed

docs/source/features/auto_deploy/support_matrix.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ In addition, the following models have been officially validated using the defau
8383
- nvidia/Llama-3_1-Nemotron-Ultra-253B-v1-FP8
8484
- nvidia/Llama-3_3-Nemotron-Super-49B-v1
8585
- nvidia/Mistral-NeMo-Minitron-8B-Base
86+
- nvidia/Nemotron-Flash-3B-Instruct
8687
- perplexity-ai/r1-1776-distill-llama-70b
8788

8889
</details>

examples/auto_deploy/.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,4 @@ benchmark_results.json
55
# ignore config files that users might put here for debugging
66
*.yaml
77
!nano_v3.yaml
8+
!nemotron_flash.yaml
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
compile_backend: torch-cudagraph
2+
max_batch_size: 384
3+
max_seq_len: 2097152
4+
max_num_tokens: 8192
5+
enable_chunked_prefill: true
6+
model_factory: NemotronFlashForCausalLM
7+
free_mem_ratio: 0.9
8+
cuda_graph_batch_sizes: [1, 2, 4, 8, 16, 24, 32, 64,96, 128, 256, 320, 384]
9+
kv_cache_config:
10+
# disable kv_cache reuse since not supported for hybrid/ssm models
11+
enable_block_reuse: false

tensorrt_llm/_torch/auto_deploy/config/default.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,9 @@ transforms:
152152
insert_cached_causal_conv:
153153
stage: cache_init
154154
backend: cuda_causal_conv
155+
insert_cached_delta_rule:
156+
stage: cache_init
157+
backend: fla_delta
155158
initialize_cache:
156159
stage: cache_init
157160
run_per_gm: false

tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,11 @@ class CacheConfig(BaseModel):
3434

3535
dtype: Optional[torch.dtype] = Field(default=None, description="KV cache dtype.")
3636
mamba_dtype: Optional[torch.dtype] = Field(default=None, description="Mamba cache dtype.")
37+
delta_dtype: Optional[torch.dtype] = Field(
38+
default=torch.float32, description="Delta cache dtype. Defaults to float32."
39+
)
3740

38-
@field_validator("dtype", "mamba_dtype", mode="before")
41+
@field_validator("dtype", "mamba_dtype", "delta_dtype", mode="before")
3942
@classmethod
4043
def _coerce_dtype(cls, value):
4144
if value is None or isinstance(value, torch.dtype):

tensorrt_llm/_torch/auto_deploy/custom_ops/fla/__init__.py

Whitespace-only changes.

tensorrt_llm/_torch/auto_deploy/custom_ops/fla/delta_rule/__init__.py

Whitespace-only changes.
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# Adapted from https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/delta_rule/chunk.py
2+
3+
from typing import Optional
4+
5+
import torch
6+
7+
from tensorrt_llm._torch.modules.fla.chunk_delta_h import chunk_gated_delta_rule_fwd_h
8+
from tensorrt_llm._torch.modules.fla.chunk_o import chunk_fwd_o
9+
10+
from .wy_fast import prepare_wy_repr_fwd
11+
12+
13+
def chunk_delta_rule_fwd(
14+
q: torch.Tensor,
15+
k: torch.Tensor,
16+
v: torch.Tensor,
17+
beta: torch.Tensor,
18+
scale: float,
19+
initial_state: torch.Tensor,
20+
output_final_state: bool,
21+
cu_seqlens: Optional[torch.LongTensor] = None,
22+
):
23+
# obtain WY representation. u is actually the new v.
24+
w, u, A = prepare_wy_repr_fwd(
25+
k=k,
26+
v=v,
27+
beta=beta,
28+
cu_seqlens=cu_seqlens,
29+
)
30+
h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
31+
k=k,
32+
w=w,
33+
u=u,
34+
g=None,
35+
initial_state=initial_state,
36+
output_final_state=output_final_state,
37+
cu_seqlens=cu_seqlens,
38+
)
39+
40+
o = chunk_fwd_o(q=q, k=k, v=v_new, h=h, g=None, scale=scale, cu_seqlens=cu_seqlens)
41+
42+
return o, A, final_state
Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
# Adapted from https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/delta_rule/fused_recurrent.py
2+
from typing import Optional, Tuple
3+
4+
import torch
5+
import triton
6+
import triton.language as tl
7+
8+
9+
@triton.heuristics(
10+
{
11+
"USE_INITIAL_STATE": lambda args: args["h0"] is not None,
12+
"STORE_FINAL_STATE": lambda args: args["ht"] is not None,
13+
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
14+
}
15+
)
16+
@triton.jit(do_not_specialize=["T"])
17+
def fused_recurrent_delta_rule_fwd_kernel(
18+
q,
19+
k,
20+
v,
21+
u,
22+
beta,
23+
o,
24+
h0,
25+
ht,
26+
cu_seqlens,
27+
scale,
28+
T,
29+
B: tl.constexpr,
30+
H: tl.constexpr,
31+
K: tl.constexpr,
32+
V: tl.constexpr,
33+
BK: tl.constexpr,
34+
BV: tl.constexpr,
35+
USE_INITIAL_STATE: tl.constexpr,
36+
STORE_FINAL_STATE: tl.constexpr,
37+
IS_BETA_HEADWISE: tl.constexpr,
38+
IS_VARLEN: tl.constexpr,
39+
):
40+
i_v, i_k, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
41+
i_n, i_h = i_nh // H, i_nh % H
42+
if IS_VARLEN:
43+
bos, eos = (
44+
tl.load(cu_seqlens + i_n).to(tl.int64),
45+
tl.load(cu_seqlens + i_n + 1).to(tl.int64),
46+
)
47+
all = T
48+
T = eos - bos
49+
else:
50+
bos, eos = i_n * T, i_n * T + T
51+
all = B * T
52+
53+
p_q = q + (bos * H + i_h) * K + i_k * BK + tl.arange(0, BK)
54+
p_k = k + (bos * H + i_h) * K + i_k * BK + tl.arange(0, BK)
55+
p_v = v + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV)
56+
p_u = u + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV)
57+
if IS_BETA_HEADWISE:
58+
p_beta = beta + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV)
59+
else:
60+
p_beta = beta + bos * H + i_h
61+
p_o = o + ((i_k * all + bos) * H + i_h) * V + i_v * BV + tl.arange(0, BV)
62+
63+
mask_k = (i_k * BK + tl.arange(0, BK)) < K
64+
mask_v = (i_v * BV + tl.arange(0, BV)) < V
65+
mask_h = mask_k[None, :] & mask_v[:, None]
66+
67+
b_h = tl.zeros([BV, BK], dtype=tl.float32)
68+
if USE_INITIAL_STATE:
69+
p_h0 = (
70+
h0
71+
+ i_nh * K * V
72+
+ (i_k * BK + tl.arange(0, BK)[None, :]) * V
73+
+ (i_v * BV + tl.arange(0, BV)[:, None])
74+
)
75+
b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
76+
77+
for _ in range(0, T):
78+
b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
79+
b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
80+
b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) * scale
81+
b_v_minus = tl.sum(b_h * b_k[None, :], axis=1)
82+
b_v -= b_v_minus
83+
if IS_BETA_HEADWISE:
84+
b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32)
85+
else:
86+
b_beta = tl.load(p_beta).to(tl.float32)
87+
tl.store(p_u, b_v.to(p_v.dtype.element_ty), mask=mask_v)
88+
b_v *= b_beta
89+
b_h += b_k[None, :] * b_v[:, None]
90+
b_o = b_h * b_q[None, :]
91+
b_o = tl.sum(b_o, axis=1)
92+
tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)
93+
94+
p_q += H * K
95+
p_k += H * K
96+
p_o += H * V
97+
p_v += H * V
98+
p_u += H * V
99+
p_beta += H * (V if IS_BETA_HEADWISE else 1)
100+
101+
if STORE_FINAL_STATE:
102+
p_ht = (
103+
ht
104+
+ i_nh * K * V
105+
+ (i_k * BK + tl.arange(0, BK)[None, :]) * V
106+
+ (i_v * BV + tl.arange(0, BV)[:, None])
107+
)
108+
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
109+
110+
111+
def fused_recurrent_delta_rule_fwd(
112+
q: torch.Tensor,
113+
k: torch.Tensor,
114+
v: torch.Tensor,
115+
beta: torch.Tensor,
116+
scale: float,
117+
initial_state: torch.Tensor,
118+
output_final_state: bool,
119+
cu_seqlens: Optional[torch.LongTensor] = None,
120+
) -> Tuple[torch.Tensor, torch.Tensor]:
121+
B, T, H, K, V = *k.shape, v.shape[-1]
122+
N = B if cu_seqlens is None else len(cu_seqlens) - 1
123+
BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 8)
124+
NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
125+
assert NK == 1, "NK > 1 is not supported yet"
126+
num_stages = 1
127+
num_warps = 1
128+
129+
o = q.new_empty(NK, *v.shape)
130+
if output_final_state:
131+
final_state = q.new_empty(N, H, K, V, dtype=torch.float32)
132+
else:
133+
final_state = None
134+
135+
grid = (NV, NK, N * H)
136+
u = torch.empty_like(v)
137+
fused_recurrent_delta_rule_fwd_kernel[grid](
138+
q,
139+
k,
140+
v,
141+
u,
142+
beta,
143+
o,
144+
initial_state,
145+
final_state,
146+
cu_seqlens,
147+
scale,
148+
T=T,
149+
B=B,
150+
H=H,
151+
K=K,
152+
V=V,
153+
BK=BK,
154+
BV=BV,
155+
IS_BETA_HEADWISE=beta.ndim == v.ndim,
156+
num_warps=num_warps,
157+
num_stages=num_stages,
158+
)
159+
o = o.squeeze(0)
160+
return o, u, final_state
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# Adapted from https://github.com/fla-org/flash-linear-attention/blob/main/fla/utils.py
2+
import inspect
3+
import os
4+
5+
import triton
6+
7+
FLA_CACHE_RESULTS = os.getenv("FLA_CACHE_RESULTS", "1") == "1"
8+
9+
10+
supports_autotune_cache = "cache_results" in inspect.signature(triton.autotune).parameters
11+
autotune_cache_kwargs = {"cache_results": FLA_CACHE_RESULTS} if supports_autotune_cache else {}

0 commit comments

Comments
 (0)