
Commit 35bb255

ptillet and Jokeren authored
[TRITON_KERNELS] basic expert parallelism implementation (#8448)
Co-authored-by: Keren Zhou <kerenzhou@openai.com>
Co-authored-by: Jokeren <robinho364@gmail.com>
1 parent 2875f40 commit 35bb255

File tree

15 files changed: +1700 −153 lines

python/triton_kernels/bench/distributed.py

Lines changed: 1 addition & 1 deletion
@@ -28,7 +28,7 @@ def legacy_routing_from_bitmatrix(bitmatrix, expt_scal, expt_indx, n_expts_tot,
     combine_indx = sparse_logits.mask_metadata.row_sorted_indx
     ragged_batch_metadata = make_ragged_tensor_metadata(sparse_logits.mask_metadata.col_sum, dispatch_indx.shape[0])
     gate_scal = sparse_logits.vals.flatten()[combine_indx]
-    routing_data = RoutingData(gate_scal, ragged_batch_metadata.batch_sizes, n_expts_tot, n_expts_act,
+    routing_data = RoutingData(gate_scal, ragged_batch_metadata.slice_sizes, n_expts_tot, n_expts_act,
                                ragged_batch_metadata)
     gather_idx = GatherIndx(combine_indx, dispatch_indx)
     scatter_idx = ScatterIndx(dispatch_indx, combine_indx)

python/triton_kernels/reduce.py

Lines changed: 280 additions & 0 deletions
@@ -0,0 +1,280 @@
from dataclasses import dataclass
import torch
import triton
import triton.language as tl
from triton_kernels.numerics_details.mxfp import quantize_mxfp8_fn
from triton_kernels.numerics_details.flexpoint import float_to_flex, load_scale
from triton_kernels.numerics import InFlexData, OutFlexData, MAX_FINITE_FLOAT8E4B8, MAX_FINITE_FLOAT8E4NV, MAX_FINITE_FLOAT8E5
from typing import Optional
import types
import sys
from .specialize import specialize

# cache of generated modules holding _reduce specialized per postprocess fn
_kernels = dict()


@dataclass(frozen=True)
class FnSpecs:
    name: str
    fn: "triton.runtime.jit.JITFunction"
    fn_arg_names: tuple[str]
    fn_arg_do_not_specialize: tuple[str] = tuple()

    @staticmethod
    def default():
        return FnSpecs("dflt", None, tuple())


@dataclass(frozen=True)
class PostprocessFn:
    specs: FnSpecs = FnSpecs.default()
    fn_args: tuple[object] = tuple()


def get_kernels(fn_specs: FnSpecs = FnSpecs.default()):
    global _kernels
    key = (fn_specs.name, )
    if key in _kernels:
        return _kernels[key]
    spec_constants = {"POSTPROCESS_FN": fn_specs.fn}
    spec_tuples = {"postprocess_fn_args": fn_specs.fn_arg_names}
    do_not_specialize = fn_specs.fn_arg_do_not_specialize
    module = types.ModuleType(f"reduce{'_'.join(key)}")
    sys.modules[module.__name__] = module
    module._reduce = specialize(_reduce, module, spec_constants, spec_tuples, do_not_specialize=do_not_specialize)
    _kernels[key] = module
    return module


@triton.jit
def _reduce(X, stride_xr, stride_x0, stride_x1,  # x tensor (input)
            XMx, stride_xmxr, stride_xmx0, stride_xmx1,  # x mx scale
            Y, stride_y0, stride_y1,  # y tensor (output)
            YMx, stride_ymx0, stride_ymx1,  # y mx scale
            Mask, stride_mr, stride_m0, stride_m1,  # mask tensor
            Scale, stride_sr, stride_s0, stride_s1,  # scale tensor
            K, S0, S1,  # shape (K = reduction dim; S0, S1 = output dims)
            POSTPROCESS_FN: tl.constexpr, postprocess_fn_args, XFlex,  # x flex (global) scale
            YFlexExpected, YFlexActual, YFlexChecksum, Y_FLEX_SATURATE_INF: tl.constexpr,  # y flex (global) scale
            IS_MASK_NONE: tl.constexpr,  #
            BROADCAST_R: tl.constexpr,  #
            BROADCAST_S0: tl.constexpr,  #
            BROADCAST_S1: tl.constexpr,  #
            IS_SCALE_NONE: tl.constexpr,  #
            SCALE_BROADCAST_R: tl.constexpr,  #
            SCALE_BROADCAST_S0: tl.constexpr,  #
            SCALE_BROADCAST_S1: tl.constexpr,  #
            BLOCK_S0: tl.constexpr,  #
            BLOCK_S1: tl.constexpr,  #
            ):
    pid_s0 = tl.program_id(0)
    pid_s1 = tl.program_id(1)
    tl.static_assert(BLOCK_S1 % 32 == 0)
    BLOCK_SMX1: tl.constexpr = BLOCK_S1 // 32
    offs_s0 = pid_s0 * BLOCK_S0 + tl.arange(0, BLOCK_S0)
    offs_s1 = pid_s1 * BLOCK_S1 + tl.arange(0, BLOCK_S1)
    offs_smx1 = pid_s1 * BLOCK_SMX1 + tl.arange(0, BLOCK_SMX1)
    valid_s0 = offs_s0 < S0
    valid_s1 = offs_s1 < S1
    valid_smx1 = offs_smx1 < tl.cdiv(S1, 32)
    y = tl.zeros((BLOCK_S0, BLOCK_S1), dtype=tl.float32)
    x_flex_scale = load_scale(XFlex)
    for k in tl.range(0, K, num_stages=2):
        x_ptrs = X + k * stride_xr + offs_s0[:, None] * stride_x0 + offs_s1[None, :] * stride_x1
        x = tl.load(x_ptrs, mask=valid_s0[:, None] & valid_s1[None, :], other=0.0)
        x = x.to(tl.float32)
        if XMx is not None:
            xmx_ptrs = XMx + k * stride_xmxr + offs_s0[:, None] * stride_xmx0 + offs_smx1[None, :] * stride_xmx1
            xmx = tl.load(xmx_ptrs, mask=valid_s0[:, None] & valid_smx1[None, :], other=0.0)
            xmx = (xmx.to(tl.uint32) << 23).to(tl.float32, bitcast=True)
            x = (xmx[:, :, None] * x.reshape([BLOCK_S0, BLOCK_S1 // 32, 32])).reshape([BLOCK_S0, BLOCK_S1])
        x = x * x_flex_scale
        if not IS_SCALE_NONE:
            k_term_s = 0 if SCALE_BROADCAST_R else (k * stride_sr)
            s0_term_s = 0 if SCALE_BROADCAST_S0 else (offs_s0[:, None] * stride_s0)
            s1_term_s = 0 if SCALE_BROADCAST_S1 else (offs_s1[None, :] * stride_s1)
            s_ptrs = Scale + k_term_s + s0_term_s + s1_term_s
            s = tl.load(s_ptrs, mask=valid_s0[:, None] & valid_s1[None, :], other=1)
            x = x * s
        if not IS_MASK_NONE:
            k_term = 0 if BROADCAST_R else (k * stride_mr)
            s0_term = 0 if BROADCAST_S0 else (offs_s0[:, None] * stride_m0)
            s1_term = 0 if BROADCAST_S1 else (offs_s1[None, :] * stride_m1)
            m_ptrs = Mask + k_term + s0_term + s1_term
            m = tl.load(m_ptrs, mask=valid_s0[:, None] & valid_s1[None, :], other=1)
            x = tl.where(m != 0, x, 0.0)
        y += x
    if POSTPROCESS_FN is not None:
        y = POSTPROCESS_FN(y, *postprocess_fn_args)
    y = float_to_flex(y, YFlexExpected, YFlexActual, YFlexChecksum, None, Y, Y_FLEX_SATURATE_INF)
    y_ptrs = Y + offs_s0[:, None] * stride_y0 + offs_s1[None, :] * stride_y1
    if YMx is not None:
        y, y_scale = quantize_mxfp8_fn(y, valid_s1[None, :])
        y_mx_ptrs = YMx + offs_s0[:, None] * stride_ymx0 + offs_smx1[None, :] * stride_ymx1
        tl.store(y_mx_ptrs, y_scale, mask=valid_s0[:, None] & valid_smx1[None, :])
    tl.store(y_ptrs, y, mask=valid_s0[:, None] & valid_s1[None, :])


def reduce(
    x: torch.Tensor,
    dim: int,
    mask: Optional[torch.Tensor] = None,
    scale: Optional[torch.Tensor] = None,
    x_mxscale: Optional[torch.Tensor] = None,
    x_flex: Optional[InFlexData] = InFlexData(),
    y_flex: Optional[OutFlexData] = OutFlexData(),
    y_flex_saturate_inf: bool = False,
    postprocess_fn: Optional[PostprocessFn] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """
    Performs a reduction over the specified dimension of the input tensor,
    optionally multiplied by `scale` and ignoring masked elements.

    Arguments:
    - x: Tensor
        input tensor to reduce.
    - dim: int
        dimension along which `x` should be reduced.
    - mask: Optional[torch.Tensor]
        integer mask of the same shape as `x` (or broadcastable to it).
        entries that are `0` are ignored in the reduction.
        if `mask is None`, all elements are included.
    - scale: Optional[torch.Tensor]
        scale factors of the same shape as `x` (or broadcastable to it).
        the reduction is performed over `x * scale`. If `scale is None`,
        a value of 1 is used everywhere.

    Returns:
    - output: torch.Tensor
        The reduced tensor with `dim` removed.
    - output mx scales: Optional[torch.Tensor]
        mx scales for the output, or `None` if `x_mxscale` is None.
    """
    if x.ndim != 3:
        raise NotImplementedError("reduce only supports 3D inputs in this implementation")
    if dim < 0:
        dim += x.ndim
    if dim not in (0, 1, 2):
        raise ValueError("dim must be in {0,1,2}")
    if x_mxscale is not None:
        if dim == 2:
            raise ValueError("reduction over the micro-scaled dimension not supported")
        assert x.shape[:-2] == x_mxscale.shape[:-2]
        assert triton.cdiv(x.shape[-1], 32) * 32 == x_mxscale.shape[-1] * 32
        assert dim != -1
    # assert not y_flex.is_per_batch
    if postprocess_fn is None:
        postprocess_fn = PostprocessFn()
    if y_flex is None:
        y_flex = OutFlexData()
    if x_flex is None:
        x_flex = InFlexData()
    # input shapes
    dims = (0, 1, 2)
    nonred = tuple(d for d in dims if d != dim)
    S0, S1 = x.shape[nonred[0]], x.shape[nonred[1]]
    y = torch.empty((S0, S1), device=x.device, dtype=x.dtype)
    y_mxscale = None
    if x_mxscale is not None:
        y_mxscale = torch.empty((S0, triton.cdiv(S1, 32)), device=x.device, dtype=x_mxscale.dtype)
    # Strides for X along reduced and non-reduced dims
    stride_xr = x.stride(dim)
    stride_x0 = x.stride(nonred[0])
    stride_x1 = x.stride(nonred[1])
    # Strides for X mx scales
    stride_xmxr = None if x_mxscale is None else x_mxscale.stride(dim)
    stride_xmx0 = None if x_mxscale is None else x_mxscale.stride(nonred[0])
    stride_xmx1 = None if x_mxscale is None else x_mxscale.stride(nonred[1])
    # Strides for Y mx scales
    stride_ymx0 = None if y_mxscale is None else y_mxscale.stride(0)
    stride_ymx1 = None if y_mxscale is None else y_mxscale.stride(1)
    # Mask strides (broadcast allowed via stride 0)
    if mask is not None:
        mstr0, mstr1, mstr2 = mask.stride()
        stride_mr = (mstr0 if dim == 0 else (mstr1 if dim == 1 else mstr2))
        stride_m0 = (mstr0 if nonred[0] == 0 else (mstr1 if nonred[0] == 1 else mstr2))
        stride_m1 = (mstr0 if nonred[1] == 0 else (mstr1 if nonred[1] == 1 else mstr2))
    else:
        stride_mr = stride_m0 = stride_m1 = 0
    # Scale strides (broadcast allowed via stride 0)
    if scale is not None:
        sstr0, sstr1, sstr2 = scale.stride()
        stride_sr = (sstr0 if dim == 0 else (sstr1 if dim == 1 else sstr2))
        stride_s0 = (sstr0 if nonred[0] == 0 else (sstr1 if nonred[0] == 1 else sstr2))
        stride_s1 = (sstr0 if nonred[1] == 0 else (sstr1 if nonred[1] == 1 else sstr2))
    else:
        stride_sr = stride_s0 = stride_s1 = 0
    K = x.shape[dim]
    # Always use the 2D tiled kernel with constexpr metaprogramming for mask broadcasting
    BLOCK_S0 = 64
    BLOCK_S1 = 128
    grid = (triton.cdiv(S0, BLOCK_S0), triton.cdiv(S1, BLOCK_S1))
    mask_arg = mask if mask is not None else x  # dummy pointer; never dereferenced when IS_MASK_NONE
    scale_arg = scale if scale is not None else x  # dummy pointer; never dereferenced when IS_SCALE_NONE
    reduce_kernel = get_kernels(postprocess_fn.specs)._reduce
    reduce_kernel[grid](
        x, stride_xr, stride_x0, stride_x1,  #
        x_mxscale, stride_xmxr, stride_xmx0, stride_xmx1,  #
        y, y.stride(0), y.stride(1),  #
        y_mxscale, stride_ymx0, stride_ymx1,  #
        mask_arg, stride_mr, stride_m0, stride_m1,  #
        scale_arg, stride_sr, stride_s0, stride_s1,  #
        K, S0, S1,  #
        *postprocess_fn.fn_args, x_flex.scale, y_flex.expected_scale, y_flex.actual_scale, y_flex.checksum_scale,
        y_flex_saturate_inf,  #
        IS_MASK_NONE=(mask is None),  #
        BROADCAST_R=(stride_mr == 0),  #
        BROADCAST_S0=(stride_m0 == 0),  #
        BROADCAST_S1=(stride_m1 == 0),  #
        IS_SCALE_NONE=(scale is None),  #
        SCALE_BROADCAST_R=(stride_sr == 0),  #
        SCALE_BROADCAST_S0=(stride_s0 == 0),  #
        SCALE_BROADCAST_S1=(stride_s1 == 0),  #
        BLOCK_S0=BLOCK_S0,  #
        BLOCK_S1=BLOCK_S1,  #
        num_warps=4  #
    )
    return y, y_mxscale


def compute_actual_scale(x, dtype, per_batch_scale=False):
    max_finite = {
        torch.float8_e5m2: MAX_FINITE_FLOAT8E5,
        torch.float8_e4m3fn: MAX_FINITE_FLOAT8E4NV,
        torch.float8_e4m3fnuz: MAX_FINITE_FLOAT8E4B8,
    }[dtype]
    maxvals = x.abs().amax(dim=tuple(range(1, x.ndim))) if per_batch_scale else x.abs().max()
    return maxvals / max_finite


def reduce_torch(x: torch.Tensor, dim: int, mask: Optional[torch.Tensor] = None,  #
                 scale: Optional[torch.Tensor] = None,  #
                 x_mxscale: Optional[torch.Tensor] = None,  #
                 x_flex: Optional[InFlexData] = InFlexData(), y_flex: Optional[OutFlexData] = OutFlexData(),
                 y_flex_saturate_inf: bool = False, postprocess_fn: Optional[callable] = None):
    from triton_kernels.numerics_details.mxfp import downcast_to_mxfp_torch, upcast_from_mxfp_torch
    x_dtype = x.dtype
    # upcast input
    if x_mxscale is not None:
        x = upcast_from_mxfp_torch(x, x_mxscale, torch.float32, axis=-1)
    x = x.to(torch.float32)
    if x_flex is not None:
        x *= x_flex.scale
    # upcast scale
    if scale is None:
        scale = torch.ones(1, dtype=torch.float32, device=x.device)
    scale = scale.to(torch.float32)
    # initialize mask
    if mask is None:
        mask = torch.ones(1, dtype=torch.bool, device=x.device)
    mask = mask.to(torch.bool)
    ret = torch.where(mask, x * scale, 0).sum(dim=dim)
    if postprocess_fn is not None:
        ret = postprocess_fn(ret)
    if y_flex is not None:
        y_flex.actual_scale.copy_(compute_actual_scale(ret, x_dtype, y_flex.is_per_batch))
        ret = (ret / y_flex.expected_scale).to(x_dtype)
    # downcast output
    ret_mxscale = None
    if x_mxscale is not None:
        assert y_flex is None
        ret, ret_mxscale = downcast_to_mxfp_torch(ret, torch.float8_e4m3fn, axis=-1)
    return ret.to(x_dtype), ret_mxscale
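
For context, here is a minimal usage sketch of the new reduce entry point checked against the reduce_torch reference, based only on the signatures in this diff. The tensor shapes, the `triton_kernels.reduce` import path, and the choice to pass `x_flex=None` / `y_flex=None` (so that flexpoint bookkeeping is effectively a no-op) are illustrative assumptions, not part of the commit.

    # Illustrative sketch only: shapes, import path, and flexpoint settings are assumptions.
    import torch
    from triton_kernels.reduce import reduce, reduce_torch  # module path assumed from the diff header

    # Sum a (K, S0, S1) stack over its first dimension, dropping masked entries
    # and pre-multiplying by a per-element scale.
    x = torch.randn(4, 64, 128, device="cuda")
    mask = (torch.rand_like(x) > 0.1).to(torch.int32)  # entries equal to 0 are ignored
    scale = torch.full_like(x, 0.5)                    # reduction runs over x * scale

    # x_flex / y_flex left as None so both paths skip flexpoint scaling (assumed no-op defaults).
    y, y_mx = reduce(x, dim=0, mask=mask, scale=scale, x_flex=None, y_flex=None)
    y_ref, _ = reduce_torch(x, dim=0, mask=mask, scale=scale, x_flex=None, y_flex=None)
    torch.testing.assert_close(y, y_ref.to(y.dtype), rtol=1e-5, atol=1e-5)  # y_mx is None without mx scales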
