import functools
import math

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
from jax.experimental.shard_map import shard_map

import torch
import torch.nn.functional as F

import numpy as np

DEFAULT_MASK_VALUE = -0.7 * float(np.finfo(np.dtype("float32")).max)


def ragged_flash_attention_kernel(
    start_ref,
    end_ref,
    line_end_ref,
    pre_b_ref,
    pre_i_ref,
    q_ref,
    k_ref,
    v_ref,
    k_scaler_ref,
    v_scaler_ref,
    o_ref,  # outputs
    m_ref,  # running row max
    l_ref,  # running softmax denominator (propagation coefficient)
    bk: int,
    mask_value: float,
    normalize_var: bool,
    quantized: bool,
):
  """Pallas kernel for ragged flash attention."""
  with jax.named_scope("attention_kernel"):
    b, i = pl.program_id(0), pl.program_id(1)

    @pl.when(i == 0)
    def init():
      with jax.named_scope("init"):
        m_ref[...] = jnp.full_like(m_ref, -jnp.inf)
        l_ref[...] = jnp.zeros_like(l_ref)
        o_ref[...] = jnp.zeros_like(o_ref)

    length = line_end_ref[b]
    start = start_ref[b]
    end = end_ref[b]

    # Only process KV blocks that can contain valid tokens; skip rows where
    # start == end.
    @pl.when(jnp.logical_and(i * bk < length, start != end))
    def run():
      with jax.named_scope("run_qk"):
        q = q_ref[...].astype(jnp.float32)
        k = k_ref[...].astype(jnp.float32)
        v = v_ref[...].astype(jnp.float32)
        m_prev, l_prev = m_ref[...], l_ref[...]

        qk = jax.lax.dot_general(
            q, k, (((1,), (1,)), ((), ())), preferred_element_type=jnp.float32
        )
        if normalize_var:
          qk = qk / jnp.sqrt(k.shape[-1])
        if quantized:
          qk = qk * k_scaler_ref[...]
      with jax.named_scope("run_mask"):
        start = start_ref[b]
        end = end_ref[b]
        iota = jax.lax.broadcasted_iota(jnp.int32, qk.shape, 1)
        # start <= end: the valid tokens are the contiguous range [start, end).
        mask_start_lt_end = jnp.logical_and(
            i * bk + iota >= start, i * bk + iota < end
        ).astype(jnp.int32)
        # start > end: the valid range wraps past the end of the cache, so a
        # token is valid if it is either >= start or < end.
        mask_start_gt_end = jnp.logical_or(
            i * bk + iota >= start, i * bk + iota < end
        ).astype(jnp.int32)
        # mask = jax.lax.cond(start <= end, lambda: mask_start_lt_end, lambda: mask_start_gt_end)
        mask = jnp.where(start <= end, mask_start_lt_end, mask_start_gt_end)

        qk = qk + jnp.where(mask, 0.0, mask_value)

      with jax.named_scope("run_softmax"):
        m_curr = qk.max(axis=-1)

        s_curr = jnp.exp(qk - m_curr[..., None])

        l_curr = jax.lax.broadcast_in_dim(
            s_curr.sum(axis=-1), l_prev.shape, (0,)
        )
        if quantized:
          s_curr = s_curr * v_scaler_ref[...]
        o_curr_times_l_curr = jnp.dot(s_curr, v)
        # Online-softmax update: rescale the previous partial output and
        # denominator by the shift in the running max, then merge this block.
        m_curr = jax.lax.broadcast_in_dim(m_curr, m_prev.shape, (0,))
        m_next = jnp.maximum(m_prev, m_curr)
        alpha = jnp.exp(m_prev - m_next)
        beta = jnp.exp(m_curr - m_next)
        l_next = alpha * l_prev + beta * l_curr
        l_next_safe = jnp.where(l_next == 0.0, 1.0, l_next)

        m_ref[...], l_ref[...] = m_next, l_next_safe
        o_ref[...] = (
            (l_prev * alpha * o_ref[...] + beta * o_curr_times_l_curr)
            / l_next_safe
        ).astype(o_ref.dtype)


@functools.partial(
    jax.jit, static_argnames=["bk", "mask_value", "normalize_var"]
)
def ragged_mqa(
    q: jax.Array,
    k: jax.Array,
    v: jax.Array,
    start: jax.Array,
    end: jax.Array,
    k_scaler: jax.Array | None = None,
    v_scaler: jax.Array | None = None,
    ragged_batch_index=None,
    ragged_block_index=None,
    bk: int = 512,
    mask_value: float = DEFAULT_MASK_VALUE,
    normalize_var: bool = True,
) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]:
  """Ragged multi query attention."""
  with jax.named_scope("ragged_mqa"):
    batch_size, num_heads, head_dim = q.shape
    seq_len = k.shape[1]

    def kv_index_map(
        b,
        i,
        start_ref,
        end_ref,
        line_end_ref,
        ragged_batch_index_ref,
        ragged_block_index_ref,
    ):
      # Use the scalar-prefetched ragged indices to choose which batch row and
      # which KV block this grid step actually loads.
      index = b * (seq_len // bk) + i
      return ragged_batch_index_ref[index], ragged_block_index_ref[index], 0

    def q_index_map(
        b,
        i,
        start_ref,
        end_ref,
        line_end_ref,
        ragged_batch_index_ref,
        ragged_block_index_ref,
    ):
      index = b * (seq_len // bk) + i
      return ragged_batch_index_ref[index], 0, 0

    def scaler_index_map(b, i, *_):
      return b, 0, i

    line_end = jnp.where(start < end, end, seq_len - 1)

    in_specs = [
        pl.BlockSpec(q_index_map, (None, num_heads, head_dim)),
        pl.BlockSpec(kv_index_map, (None, bk, head_dim)),
        pl.BlockSpec(kv_index_map, (None, bk, head_dim)),
    ]
    inputs = (
        start,
        end,
        line_end,
        ragged_batch_index,
        ragged_block_index,
        q,
        k,
        v,
    )
    quantized = False
    if k_scaler is not None:
      in_specs = in_specs + [
          pl.BlockSpec(scaler_index_map, (None, 1, bk)),
          pl.BlockSpec(scaler_index_map, (None, 1, bk)),
      ]
      inputs = inputs + (k_scaler, v_scaler)
      quantized = True

    out, m, l = pl.pallas_call(
        functools.partial(
            ragged_flash_attention_kernel,
            bk=bk,
            mask_value=mask_value,
            normalize_var=normalize_var,
            quantized=quantized,
        ),
        grid_spec=pltpu.PrefetchScalarGridSpec(
            num_scalar_prefetch=5,
            in_specs=in_specs,
            out_specs=[
                pl.BlockSpec(q_index_map, (None, num_heads, head_dim)),
                pl.BlockSpec(q_index_map, (None, num_heads, head_dim)),
                pl.BlockSpec(q_index_map, (None, num_heads, head_dim)),
            ],
            grid=(batch_size, seq_len // bk),
        ),
        compiler_params={"dimension_semantics": ("parallel", "arbitrary")},
        out_shape=[
            q,
            jax.ShapeDtypeStruct(
                (batch_size, num_heads, head_dim), jnp.float32
            ),
            jax.ShapeDtypeStruct(
                (batch_size, num_heads, head_dim), jnp.float32
            ),
        ],
    )(*inputs)
  return out, (m[..., 0], l[..., 0])


@functools.partial(
    jax.jit, static_argnames=["bk", "mask_value", "normalize_var", "shard_axis"]
)
def ragged_mha(
    q: jax.Array,
    k: jax.Array,
    v: jax.Array,
    start: jax.Array,
    end: jax.Array,
    ragged_batch_index: jax.Array,
    ragged_block_index: jax.Array,
    k_scaler: jax.Array | None = None,
    v_scaler: jax.Array | None = None,
    bk: int = 512,
    mask_value: float = DEFAULT_MASK_VALUE,
    normalize_var: bool = True,
    shard_axis: int = 1,
) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]:
  """Ragged multi head attention.

  Args:
    q: A [batch_size, num_heads, compute_dim, head_dim] jax.Array.
    k: A [batch_size, num_heads, seq_len, head_dim] jax.Array or
      PartitionQuantizedTensor.
    v: A [batch_size, num_heads, seq_len, head_dim] jax.Array or
      PartitionQuantizedTensor.
    start: An i32[batch_size] jax.Array; the first valid KV position per row.
    end: An i32[batch_size] jax.Array; one past the last valid KV position per
      row. When start > end, the valid range wraps around the cache.
    ragged_batch_index: Precomputed per-(batch, block) batch indices consumed
      by the kernel's index maps.
    ragged_block_index: Precomputed per-(batch, block) block indices consumed
      by the kernel's index maps.
    k_scaler: Optional per-token key scales for a quantized KV cache.
    v_scaler: Optional per-token value scales for a quantized KV cache.
    bk: An integer that is the sequence block size.
    mask_value: The value used for padding in attention. By default it is a
      very negative floating point number.
    normalize_var: Whether to scale the logits by 1/sqrt(head_dim).
    shard_axis: The axis (typically the head axis) that is vmapped over.

  Returns:
    The output of attention ([batch_size, num_heads, compute_dim, head_dim]),
    along with the max logit and the softmax denominator
    ([batch_size, num_heads, compute_dim]).
  """
  if k_scaler is None:
    replicated_in_axes = 4
    replicated_inputs = (ragged_batch_index, ragged_block_index)
  else:
    replicated_in_axes = 6
    replicated_inputs = (
        jnp.squeeze(k_scaler, -1),
        jnp.squeeze(v_scaler, -1),
        ragged_batch_index,
        ragged_block_index,
    )

  with jax.named_scope("ragged_mha_vmap"):
    out, (m, l) = jax.vmap(
        functools.partial(
            ragged_mqa,
            bk=bk,
            mask_value=mask_value,
            normalize_var=normalize_var,
        ),
        in_axes=(
            shard_axis,
            shard_axis,
            shard_axis,
            *([None] * replicated_in_axes),
        ),
        out_axes=shard_axis,
    )(q, k, v, start, end, *replicated_inputs)
  return out, (m, l)


def dense_attention(xq, keys, values, k_scaler=None, v_scaler=None, mask=None):
  """The vanilla dense attention implementation."""

  bsz, _, _, head_dim = xq.shape
  with jax.named_scope("attn_mat1"):
    ## Attention start
    scores = torch.einsum("ikjl,ikml->ikjm", xq, keys) / math.sqrt(head_dim)
    if k_scaler is not None:
      scores = scores * (k_scaler.reshape(bsz, 1, 1, keys.shape[2]))
    if mask is not None:
      scores = scores + mask  # (bs, n_local_heads, seqlen, max_seqlen)
  with jax.named_scope("attn_soft"):
    scores = F.softmax(scores.float(), dim=-1).type_as(xq)
    if v_scaler is not None:
      scores = scores * v_scaler.reshape((bsz, 1, 1, keys.shape[2]))

  with jax.named_scope("attn_mat2"):
    # (bs, n_local_heads, seqlen, head_dim)
    output = torch.einsum("ikjm,ikml->ikjl", scores, values)
  return output


class RaggedAttentionKernel:
  """Ragged attention kernel."""

  def __init__(self, env, input_specs, output_specs, sharding_axis):
    # Bind the static arguments first, then shard_map the bound function so
    # that bk and shard_axis are not silently dropped.
    self.binded_ragged_mha = functools.partial(
        ragged_mha, bk=env.block_size, shard_axis=sharding_axis
    )
    self.binded_ragged_mha = shard_map(
        self.binded_ragged_mha,
        env.mesh,
        input_specs,
        output_specs,
        check_rep=False,
    )
    self.binded_ragged_mha = jax.jit(self.binded_ragged_mha)

  def __call__(
      self,
      xq,
      keys,
      values,
      start,
      end,
      ragged_batch_index,
      ragged_block_index,
      k_scaler=None,
      v_scaler=None,
  ):
    return self.binded_ragged_mha(
        xq,
        keys,
        values,
        start,
        end,
        ragged_batch_index,
        ragged_block_index,
        k_scaler,
        v_scaler,
    )
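

if __name__ == "__main__":
  # Illustrative usage sketch (not part of the original module). The shapes,
  # the decode-style compute_dim of 1, and the identity block mapping below
  # are assumptions for demonstration; running this requires a TPU backend and
  # the JAX/Pallas version whose BlockSpec/PrefetchScalarGridSpec APIs are
  # used above.
  batch, heads, seq_len, head_dim, bk = 2, 8, 1024, 128, 512

  kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
  q = jax.random.normal(kq, (batch, heads, 1, head_dim), dtype=jnp.float32)
  k = jax.random.normal(kk, (batch, heads, seq_len, head_dim), dtype=jnp.float32)
  v = jax.random.normal(kv, (batch, heads, seq_len, head_dim), dtype=jnp.float32)

  # Valid KV range per batch row (no wrap-around in this example).
  start = jnp.array([0, 0], dtype=jnp.int32)
  end = jnp.array([700, 300], dtype=jnp.int32)

  # Identity mapping: grid step (b, i) reads batch row b and KV block i. A
  # production caller would precompute these so fully-masked blocks can be
  # skipped rather than merely masked out.
  num_blocks = seq_len // bk
  ragged_batch_index = jnp.repeat(jnp.arange(batch, dtype=jnp.int32), num_blocks)
  ragged_block_index = jnp.tile(jnp.arange(num_blocks, dtype=jnp.int32), batch)

  out, (m, l) = ragged_mha(
      q, k, v, start, end, ragged_batch_index, ragged_block_index, bk=bk
  )
  # out: (batch, heads, 1, head_dim); m, l: (batch, heads, 1).
  print(out.shape, m.shape, l.shape)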