|
| 1 | +from dataclasses import dataclass |
| 2 | +import torch |
| 3 | +import triton |
| 4 | +import triton.language as tl |
| 5 | +from triton_kernels.numerics_details.mxfp import quantize_mxfp8_fn |
| 6 | +from triton_kernels.numerics_details.flexpoint import float_to_flex, load_scale |
| 7 | +from triton_kernels.numerics import InFlexData, OutFlexData, MAX_FINITE_FLOAT8E4B8, MAX_FINITE_FLOAT8E4NV, MAX_FINITE_FLOAT8E5 |
| 8 | +from typing import Optional |
| 9 | +import types |
| 10 | +import sys |
| 11 | +from .specialize import specialize |
| 12 | + |
| 13 | +_kernels = dict() |
| 14 | + |
| 15 | + |
| 16 | +@dataclass(frozen=True) |
| 17 | +class FnSpecs: |
| 18 | + name: str |
| 19 | + fn: "triton.runtime.jit.JITFunction" |
| 20 | + fn_arg_names: tuple[str] |
| 21 | + fn_arg_do_not_specialize: tuple[str] = tuple() |
| 22 | + |
| 23 | + @staticmethod |
| 24 | + def default(): |
| 25 | + return FnSpecs("dflt", None, tuple()) |
| 26 | + |
| 27 | + |
| 28 | +@dataclass(frozen=True) |
| 29 | +class PostprocessFn: |
| 30 | + specs: FnSpecs = FnSpecs.default() |
| 31 | + fn_args: tuple[object] = tuple() |
| 32 | + |
| 33 | + |
| 34 | +def get_kernels(fn_specs: FnSpecs = FnSpecs.default()): |
| 35 | + global _kernels |
| 36 | + key = (fn_specs.name, ) |
| 37 | + if key in _kernels: |
| 38 | + return _kernels[key] |
| 39 | + spec_constants = {"POSTPROCESS_FN": fn_specs.fn} |
| 40 | + spec_tuples = {"postprocess_fn_args": fn_specs.fn_arg_names} |
| 41 | + do_not_specialize = fn_specs.fn_arg_do_not_specialize |
| 42 | + module = types.ModuleType(f"reduce{'_'.join(key)}") |
| 43 | + sys.modules[module.__name__] = module |
| 44 | + module._reduce = specialize(_reduce, module, spec_constants, spec_tuples, do_not_specialize=do_not_specialize) |
| 45 | + _kernels[key] = module |
| 46 | + return module |
| 47 | + |
| 48 | + |
| 49 | +@triton.jit |
| 50 | +def _reduce(X, stride_xr, stride_x0, stride_x1, # x tensor (input) |
| 51 | + XMx, stride_xmxr, stride_xmx0, stride_xmx1, # x mx scale |
| 52 | + Y, stride_y0, stride_y1, # y tensor (output) |
| 53 | + YMx, stride_ymx0, stride_ymx1, # y mx scale |
| 54 | + Mask, stride_mr, stride_m0, stride_m1, # mask tensor |
| 55 | + Scale, stride_sr, stride_s0, stride_s1, # scale tensor |
| 56 | + K, S0, S1, # shape (K = reduction dim; S0, S1 = output dims) |
| 57 | + POSTPROCESS_FN: tl.constexpr, postprocess_fn_args, XFlex, # x flex (global) scale |
| 58 | + YFlexExpected, YFlexActual, YFlexChecksum, Y_FLEX_SATURATE_INF: tl.constexpr, # y flex (global) scale |
| 59 | + IS_MASK_NONE: tl.constexpr, # |
| 60 | + BROADCAST_R: tl.constexpr, # |
| 61 | + BROADCAST_S0: tl.constexpr, # |
| 62 | + BROADCAST_S1: tl.constexpr, # |
| 63 | + IS_SCALE_NONE: tl.constexpr, # |
| 64 | + SCALE_BROADCAST_R: tl.constexpr, # |
| 65 | + SCALE_BROADCAST_S0: tl.constexpr, # |
| 66 | + SCALE_BROADCAST_S1: tl.constexpr, # |
| 67 | + BLOCK_S0: tl.constexpr, # |
| 68 | + BLOCK_S1: tl.constexpr, # |
| 69 | + ): |
| 70 | + pid_s0 = tl.program_id(0) |
| 71 | + pid_s1 = tl.program_id(1) |
| 72 | + tl.static_assert(BLOCK_S1 % 32 == 0) |
| 73 | + BLOCK_SMX1: tl.constexpr = BLOCK_S1 // 32 |
| 74 | + offs_s0 = pid_s0 * BLOCK_S0 + tl.arange(0, BLOCK_S0) |
| 75 | + offs_s1 = pid_s1 * BLOCK_S1 + tl.arange(0, BLOCK_S1) |
| 76 | + offs_smx1 = pid_s1 * BLOCK_SMX1 + tl.arange(0, BLOCK_SMX1) |
| 77 | + valid_s0 = offs_s0 < S0 |
| 78 | + valid_s1 = offs_s1 < S1 |
| 79 | + valid_smx1 = offs_smx1 < tl.cdiv(S1, 32) |
| 80 | + y = tl.zeros((BLOCK_S0, BLOCK_S1), dtype=tl.float32) |
| 81 | + x_flex_scale = load_scale(XFlex) |
| 82 | + for k in tl.range(0, K, num_stages=2): |
| 83 | + x_ptrs = X + k * stride_xr + offs_s0[:, None] * stride_x0 + offs_s1[None, :] * stride_x1 |
| 84 | + x = tl.load(x_ptrs, mask=valid_s0[:, None] & valid_s1[None, :], other=0.0) |
| 85 | + x = x.to(tl.float32) |
| 86 | + if XMx is not None: |
| 87 | + xmx_ptrs = XMx + k * stride_xmxr + offs_s0[:, None] * stride_xmx0 + offs_smx1[None, :] * stride_xmx1 |
| 88 | + xmx = tl.load(xmx_ptrs, mask=valid_s0[:, None] & valid_smx1[None, :], other=0.0) |
| 89 | + xmx = (xmx.to(tl.uint32) << 23).to(tl.float32, bitcast=True) |
| 90 | + x = (xmx[:, :, None] * x.reshape([BLOCK_S0, BLOCK_S1 // 32, 32])).reshape([BLOCK_S0, BLOCK_S1]) |
| 91 | + x = x * x_flex_scale |
| 92 | + if not IS_SCALE_NONE: |
| 93 | + k_term_s = 0 if SCALE_BROADCAST_R else (k * stride_sr) |
| 94 | + s0_term_s = 0 if SCALE_BROADCAST_S0 else (offs_s0[:, None] * stride_s0) |
| 95 | + s1_term_s = 0 if SCALE_BROADCAST_S1 else (offs_s1[None, :] * stride_s1) |
| 96 | + s_ptrs = Scale + k_term_s + s0_term_s + s1_term_s |
| 97 | + s = tl.load(s_ptrs, mask=valid_s0[:, None] & valid_s1[None, :], other=1) |
| 98 | + x = x * s |
| 99 | + if not IS_MASK_NONE: |
| 100 | + k_term = 0 if BROADCAST_R else (k * stride_mr) |
| 101 | + s0_term = 0 if BROADCAST_S0 else (offs_s0[:, None] * stride_m0) |
| 102 | + s1_term = 0 if BROADCAST_S1 else (offs_s1[None, :] * stride_m1) |
| 103 | + m_ptrs = Mask + k_term + s0_term + s1_term |
| 104 | + m = tl.load(m_ptrs, mask=valid_s0[:, None] & valid_s1[None, :], other=1) |
| 105 | + x = tl.where(m != 0, x, 0.0) |
| 106 | + y += x |
| 107 | + if POSTPROCESS_FN is not None: |
| 108 | + y = POSTPROCESS_FN(y, *postprocess_fn_args) |
| 109 | + y = float_to_flex(y, YFlexExpected, YFlexActual, YFlexChecksum, None, Y, Y_FLEX_SATURATE_INF) |
| 110 | + y_ptrs = Y + offs_s0[:, None] * stride_y0 + offs_s1[None, :] * stride_y1 |
| 111 | + if YMx is not None: |
| 112 | + y, y_scale = quantize_mxfp8_fn(y, valid_s1[None, :]) |
| 113 | + y_mx_ptrs = YMx + offs_s0[:, None] * stride_ymx0 + offs_smx1[None, :] * stride_ymx1 |
| 114 | + tl.store(y_mx_ptrs, y_scale, mask=valid_s0[:, None] & valid_smx1[None, :]) |
| 115 | + tl.store(y_ptrs, y, mask=valid_s0[:, None] & valid_s1[None, :]) |
| 116 | + |
| 117 | + |
| 118 | +def reduce( |
| 119 | + x: torch.Tensor, |
| 120 | + dim: int, |
| 121 | + mask: Optional[torch.Tensor] = None, |
| 122 | + scale: Optional[torch.Tensor] = None, |
| 123 | + x_mxscale: Optional[torch.Tensor] = None, |
| 124 | + x_flex: Optional[InFlexData] = InFlexData(), |
| 125 | + y_flex: Optional[OutFlexData] = OutFlexData(), |
| 126 | + y_flex_saturate_inf: bool = False, |
| 127 | + postprocess_fn: Optional[PostprocessFn] = None, |
| 128 | +) -> tuple[torch.Tensor, Optional[torch.Tensor]]: |
| 129 | + """ |
| 130 | + Performs a reduction over the specified dimension of the input tensor, |
| 131 | + optionally multiplied by `scale` and ignoring masked elements. |
| 132 | +
|
| 133 | + Arguments: |
| 134 | + - x: Tensor |
| 135 | + input tensor to reduce. |
| 136 | + - dim: int |
| 137 | + dimension along which `x` should be reduce. |
| 138 | + - mask: Optional[torch.Tensor] |
| 139 | + integer mask of the same shape as `x` (or broadcastable to it). |
| 140 | + entries that are `0` are ignored in the reduction. |
| 141 | + if `mask is None`, all elements are included. |
| 142 | + - scale: Optional[torch.Tensor] |
| 143 | + scale factors of the same shape as `x` (or broadcastable to it). |
| 144 | + the reduction is performed over `x * scale`. If `scale is None`, |
| 145 | + a value of 1 is used everywhere. |
| 146 | +
|
| 147 | + Returns: |
| 148 | + - output: torch.Tensor |
| 149 | + The reduced tensor with `dim` removed. |
| 150 | + """ |
| 151 | + if x.ndim != 3: |
| 152 | + raise NotImplementedError("reduce only supports 3D inputs in this implementation") |
| 153 | + if dim < 0: |
| 154 | + dim += x.ndim |
| 155 | + if dim not in (0, 1, 2): |
| 156 | + raise ValueError("dim must be in {0,1,2}") |
| 157 | + if x_mxscale is not None: |
| 158 | + if dim == 2: |
| 159 | + raise ValueError("reduction over the micro-scaled dimension not supported") |
| 160 | + assert x.shape[:-2] == x_mxscale.shape[:-2] |
| 161 | + assert triton.cdiv(x.shape[-1], 32) * 32 == x_mxscale.shape[-1] * 32 |
| 162 | + assert dim != -1 |
| 163 | + # assert not y_flex.is_per_batch |
| 164 | + if postprocess_fn is None: |
| 165 | + postprocess_fn = PostprocessFn() |
| 166 | + if y_flex is None: |
| 167 | + y_flex = OutFlexData() |
| 168 | + if x_flex is None: |
| 169 | + x_flex = InFlexData() |
| 170 | + # input shapes |
| 171 | + dims = (0, 1, 2) |
| 172 | + nonred = tuple(d for d in dims if d != dim) |
| 173 | + S0, S1 = x.shape[nonred[0]], x.shape[nonred[1]] |
| 174 | + y = torch.empty((S0, S1), device=x.device, dtype=x.dtype) |
| 175 | + y_mxscale = None |
| 176 | + if x_mxscale is not None: |
| 177 | + y_mxscale = torch.empty((S0, triton.cdiv(S1, 32)), device=x.device, dtype=x_mxscale.dtype) |
| 178 | + # Strides for X along reduced and non-reduced dims |
| 179 | + stride_xr = x.stride(dim) |
| 180 | + stride_x0 = x.stride(nonred[0]) |
| 181 | + stride_x1 = x.stride(nonred[1]) |
| 182 | + # Strides for X mx scales |
| 183 | + stride_xmxr = None if x_mxscale is None else x_mxscale.stride(dim) |
| 184 | + stride_xmx0 = None if x_mxscale is None else x_mxscale.stride(nonred[0]) |
| 185 | + stride_xmx1 = None if x_mxscale is None else x_mxscale.stride(nonred[1]) |
| 186 | + # Strides for Y mx scales |
| 187 | + stride_ymx0 = None if y_mxscale is None else y_mxscale.stride(0) |
| 188 | + stride_ymx1 = None if y_mxscale is None else y_mxscale.stride(1) |
| 189 | + # Mask strides (broadcast allowed via stride 0) |
| 190 | + if mask is not None: |
| 191 | + mstr0, mstr1, mstr2 = mask.stride() |
| 192 | + stride_mr = (mstr0 if dim == 0 else (mstr1 if dim == 1 else mstr2)) |
| 193 | + stride_m0 = (mstr0 if nonred[0] == 0 else (mstr1 if nonred[0] == 1 else mstr2)) |
| 194 | + stride_m1 = (mstr0 if nonred[1] == 0 else (mstr1 if nonred[1] == 1 else mstr2)) |
| 195 | + else: |
| 196 | + stride_mr = stride_m0 = stride_m1 = 0 |
| 197 | + # Scale strides (broadcast allowed via stride 0) |
| 198 | + if scale is not None: |
| 199 | + sstr0, sstr1, sstr2 = scale.stride() |
| 200 | + stride_sr = (sstr0 if dim == 0 else (sstr1 if dim == 1 else sstr2)) |
| 201 | + stride_s0 = (sstr0 if nonred[0] == 0 else (sstr1 if nonred[0] == 1 else sstr2)) |
| 202 | + stride_s1 = (sstr0 if nonred[1] == 0 else (sstr1 if nonred[1] == 1 else sstr2)) |
| 203 | + else: |
| 204 | + stride_sr = stride_s0 = stride_s1 = 0 |
| 205 | + K = x.shape[dim] |
| 206 | + # Always use the 2D tiled kernel with constexpr metaprogramming for mask broadcasting |
| 207 | + BLOCK_S0 = 64 |
| 208 | + BLOCK_S1 = 128 |
| 209 | + grid = (triton.cdiv(S0, BLOCK_S0), triton.cdiv(S1, BLOCK_S1)) |
| 210 | + mask_arg = mask if mask is not None else x |
| 211 | + scale_arg = scale if scale is not None else x |
| 212 | + reduce_kernel = get_kernels(postprocess_fn.specs)._reduce |
| 213 | + reduce_kernel[grid]( |
| 214 | + x, stride_xr, stride_x0, stride_x1, # |
| 215 | + x_mxscale, stride_xmxr, stride_xmx0, stride_xmx1, # |
| 216 | + y, y.stride(0), y.stride(1), # |
| 217 | + y_mxscale, stride_ymx0, stride_ymx1, # |
| 218 | + mask_arg, stride_mr, stride_m0, stride_m1, # |
| 219 | + scale_arg, stride_sr, stride_s0, stride_s1, # |
| 220 | + K, S0, S1, # |
| 221 | + *postprocess_fn.fn_args, x_flex.scale, y_flex.expected_scale, y_flex.actual_scale, y_flex.checksum_scale, |
| 222 | + y_flex_saturate_inf, # |
| 223 | + IS_MASK_NONE=(mask is None), # |
| 224 | + BROADCAST_R=(stride_mr == 0), # |
| 225 | + BROADCAST_S0=(stride_m0 == 0), # |
| 226 | + BROADCAST_S1=(stride_m1 == 0), # |
| 227 | + IS_SCALE_NONE=(scale is None), # |
| 228 | + SCALE_BROADCAST_R=(stride_sr == 0), # |
| 229 | + SCALE_BROADCAST_S0=(stride_s0 == 0), # |
| 230 | + SCALE_BROADCAST_S1=(stride_s1 == 0), # |
| 231 | + BLOCK_S0=BLOCK_S0, # |
| 232 | + BLOCK_S1=BLOCK_S1, # |
| 233 | + num_warps=4 # |
| 234 | + ) |
| 235 | + return y, y_mxscale |
| 236 | + |
| 237 | + |
| 238 | +def compute_actual_scale(x, dtype, per_batch_scale=False): |
| 239 | + max_finite = { |
| 240 | + torch.float8_e5m2: MAX_FINITE_FLOAT8E5, |
| 241 | + torch.float8_e4m3fn: MAX_FINITE_FLOAT8E4NV, |
| 242 | + torch.float8_e4m3fnuz: MAX_FINITE_FLOAT8E4B8, |
| 243 | + }[dtype] |
| 244 | + maxvals = x.abs().amax(dim=tuple(range(1, x.ndim))) if per_batch_scale else x.abs().max() |
| 245 | + return maxvals / max_finite |
| 246 | + |
| 247 | + |
| 248 | +def reduce_torch(x: torch.Tensor, dim: int, mask: Optional[torch.Tensor] = None, # |
| 249 | + scale: Optional[torch.Tensor] = None, # |
| 250 | + x_mxscale: Optional[torch.Tensor] = None, # |
| 251 | + x_flex: Optional[InFlexData] = InFlexData(), y_flex: Optional[OutFlexData] = OutFlexData(), |
| 252 | + y_flex_saturate_inf: bool = False, postprocess_fn: Optional[callable] = None): |
| 253 | + from triton_kernels.numerics_details.mxfp import downcast_to_mxfp_torch, upcast_from_mxfp_torch |
| 254 | + x_dtype = x.dtype |
| 255 | + # upcast input |
| 256 | + if x_mxscale is not None: |
| 257 | + x = upcast_from_mxfp_torch(x, x_mxscale, torch.float32, axis=-1) |
| 258 | + x = x.to(torch.float32) |
| 259 | + if x_flex is not None: |
| 260 | + x *= x_flex.scale |
| 261 | + # upcast scale |
| 262 | + if scale is None: |
| 263 | + scale = torch.ones(1, dtype=torch.float32, device=x.device) |
| 264 | + scale = scale.to(torch.float32) |
| 265 | + # initialize mask |
| 266 | + if mask is None: |
| 267 | + mask = torch.ones(1, dtype=torch.bool, device=x.device) |
| 268 | + mask = mask.to(torch.bool) |
| 269 | + ret = torch.where(mask, x * scale, 0).sum(dim=dim) |
| 270 | + if postprocess_fn is not None: |
| 271 | + ret = postprocess_fn(ret) |
| 272 | + if y_flex is not None: |
| 273 | + y_flex.actual_scale.copy_(compute_actual_scale(ret, x_dtype, y_flex.is_per_batch)) |
| 274 | + ret = (ret / y_flex.expected_scale).to(x_dtype) |
| 275 | + # downcast output |
| 276 | + ret_mxscale = None |
| 277 | + if x_mxscale is not None: |
| 278 | + assert y_flex is None |
| 279 | + ret, ret_mxscale = downcast_to_mxfp_torch(ret, torch.float8_e4m3fn, axis=-1) |
| 280 | + return ret.to(x_dtype), ret_mxscale |
0 commit comments