# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
# mypy: ignore-errors
import warnings
from typing import Optional

import torch
from einops import rearrange

from vllm.model_executor.layers.fla.ops.l2norm import l2norm_fwd
from vllm.model_executor.layers.fla.ops.utils import SUPPRESS_LEVEL

from .chunk_delta_h import chunk_gated_delta_rule_fwd_h
from .chunk_o import chunk_fwd_o
from .chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd
from .cumsum import chunk_local_cumsum
from .solve_tril import solve_tril
from .utils import input_guard
from .wy_fast import recompute_w_u_fwd


def chunk_gated_delta_rule_fwd(q: torch.Tensor,
                               k: torch.Tensor,
                               v: torch.Tensor,
                               g: torch.Tensor,
                               beta: torch.Tensor,
                               scale: float,
                               initial_state: torch.Tensor,
                               output_final_state: bool,
                               cu_seqlens: Optional[torch.LongTensor] = None):
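    # Chunkwise forward pass of the gated delta rule. Per token the recurrence
    # is (roughly) S_t = alpha_t * (I - beta_t * k_t k_t^T) @ S_{t-1}
    # + beta_t * k_t v_t^T, with alpha_t = exp(g_t); the steps below evaluate
    # it in 64-token chunks. First accumulate the log-gates within each chunk.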
    g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens)
    # obtain WY representation. u is actually the new v.
    A = chunk_scaled_dot_kkt_fwd(k=k,
                                 beta=beta,
                                 g_cumsum=g,
                                 cu_seqlens=cu_seqlens,
                                 output_dtype=torch.float32)
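    # Invert the lower-triangular chunk-local matrices (the UT transform of
    # the WY representation); A now effectively holds A^{-1} in k's dtype.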
    A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype)
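    # From the inverted A, recompute the W/U factors of the WY representation;
    # `u` plays the role of the updated (pseudo) values.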
    w, u = recompute_w_u_fwd(
        k=k,
        v=v,
        beta=beta,
        A=A,
        g_cumsum=g,
        cu_seqlens=cu_seqlens,
    )
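    # Scan over chunks to materialize the per-chunk hidden states h, the
    # updated values v_new, and (optionally) the final recurrent state.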
    h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
        k=k,
        w=w,
        u=u,
        g=g,
        initial_state=initial_state,
        output_final_state=output_final_state,
        cu_seqlens=cu_seqlens,
    )
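    # Produce the outputs: an intra-chunk attention term plus an inter-chunk
    # term that reads from the hidden states h.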
    o = chunk_fwd_o(
        q=q,
        k=k,
        v=v_new,
        h=h,
        g=g,
        scale=scale,
        cu_seqlens=cu_seqlens,
    )
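    # SUPPRESS_LEVEL controls how many intermediates are handed back (e.g. so
    # a backward pass could reuse w, h and v_new instead of recomputing them).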
    if SUPPRESS_LEVEL < 3:
        return g, o, A, final_state, None, None, None
    else:
        return g, o, A, final_state, w, h, v_new


class ChunkGatedDeltaRuleFunction(torch.autograd.Function):

    @staticmethod
    @input_guard
    def forward(ctx,
                q: torch.Tensor,
                k: torch.Tensor,
                v: torch.Tensor,
                g: torch.Tensor,
                beta: torch.Tensor,
                scale: float,
                initial_state: torch.Tensor,
                output_final_state: bool,
                cu_seqlens: Optional[torch.LongTensor] = None,
                use_qk_l2norm_in_kernel: bool = False):
        if use_qk_l2norm_in_kernel:
            q = l2norm_fwd(q)
            k = l2norm_fwd(k)

        g, o, A, final_state, w, h, v_new = chunk_gated_delta_rule_fwd(
            q=q,
            k=k,
            v=v,
            g=g,
            beta=beta,
            scale=scale,
            initial_state=initial_state,
            output_final_state=output_final_state,
            cu_seqlens=cu_seqlens,
        )
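        # Note: no backward() is defined in this file; the ctx fields appear
        # to be kept only for parity with the original training implementation.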
        ctx.scale = scale
        ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel
        return o.to(q.dtype), final_state


@torch.compiler.disable
def chunk_gated_delta_rule(q: torch.Tensor,
                           k: torch.Tensor,
                           v: torch.Tensor,
                           g: torch.Tensor,
                           beta: torch.Tensor,
                           scale: Optional[float] = None,
                           initial_state: Optional[torch.Tensor] = None,
                           output_final_state: bool = False,
                           cu_seqlens: Optional[torch.LongTensor] = None,
                           head_first: bool = False,
                           use_qk_l2norm_in_kernel: bool = False):
    r"""
    Args:
        q (torch.Tensor):
            queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
        k (torch.Tensor):
            keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
        v (torch.Tensor):
            values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
        g (torch.Tensor):
            (forget) gating tensor (in log space!) of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`.
        beta (torch.Tensor):
            betas of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`.
        scale (Optional[float]):
            Scale factor for the attention scores.
            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
        initial_state (Optional[torch.Tensor]):
            Initial state of shape `[N, H, K, V]` for `N` input sequences.
            For equal-length input sequences, `N` equals the batch size `B`.
            Default: `None`.
        output_final_state (Optional[bool]):
            Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
        cu_seqlens (torch.LongTensor):
            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
            consistent with the FlashAttention API.
        head_first (Optional[bool]):
            Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
            Default: `False`.

    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
        final_state (torch.Tensor):
            Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.

    Examples::
        >>> import torch
        >>> import torch.nn.functional as F
        >>> from einops import rearrange
        >>> from fla.ops.gated_delta_rule import chunk_gated_delta_rule
        # inputs with equal lengths
        >>> B, T, H, K, V = 4, 2048, 4, 512, 512
        >>> q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda')
        >>> k = F.normalize(torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda'), p=2, dim=-1)
        >>> v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda')
        >>> beta = torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda').sigmoid()
        >>> g = F.logsigmoid(torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda'))
        >>> h0 = torch.randn(B, H, K, V, dtype=torch.bfloat16, device='cuda')
        >>> o, ht = chunk_gated_delta_rule(
        ...     q, k, v, g, beta,
        ...     initial_state=h0,
        ...     output_final_state=True
        ... )
        # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
        >>> q, k, v, beta, g = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta, g))
        # for a batch with 4 sequences, a `cu_seqlens` tensor with 5 boundary positions is expected
        >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
        >>> o_var, ht_var = chunk_gated_delta_rule(
        ...     q, k, v, g, beta,
        ...     initial_state=h0,
        ...     output_final_state=True,
        ...     cu_seqlens=cu_seqlens
        ... )
    """
    assert q.dtype == k.dtype == v.dtype
    assert q.dtype != torch.float32, "ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16."
    assert beta.dim() == 3, "beta must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise."

    if head_first:
        warnings.warn(
            "head_first is deprecated and will be removed in a future version. "
            "Please use head_first=False for now instead.",
            DeprecationWarning,
            stacklevel=2)
        q, k, v, beta, g = map(
            lambda x: rearrange(x, 'b h t ... -> b t h ...'),
            (q, k, v, beta, g))
    if not head_first and q.shape[1] < q.shape[2]:
        warnings.warn(
            f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
            "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
            "when head_first=False was specified. "
            "Please verify your input tensor format matches the expected shape [B, T, H, ...].",
            stacklevel=2)
    if cu_seqlens is not None:
        if q.shape[0] != 1:
            raise ValueError(
                f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`. "
                "Please flatten variable-length inputs before processing.")
        if initial_state is not None and initial_state.shape[0] != len(
                cu_seqlens) - 1:
            raise ValueError(
                "The number of initial states is expected to be equal to the number of input sequences, "
                f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
            )
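    # Default to the conventional 1/sqrt(K) attention scaling.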
    if scale is None:
        scale = k.shape[-1]**-0.5
    o, final_state = ChunkGatedDeltaRuleFunction.apply(
        q, k, v, g, beta, scale, initial_state, output_final_state, cu_seqlens,
        use_qk_l2norm_in_kernel)
    if head_first:
        o = rearrange(o, 'b t h ... -> b h t ...')
    return o, final_state