|
| 1 | +import triton |
| 2 | +import triton.language as tl |
| 3 | + |
| 4 | + |
@triton.jit
def layer_norm_fwd_fused(
    X,  # pointer to the input
    Y,  # pointer to the output
    W,  # pointer to the weights
    B,  # pointer to the biases
    Mean,  # pointer to the mean
    Rstd,  # pointer to the 1/std
    stride,  # how much to increase the pointer when moving by 1 row
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    BLOCK_SIZE: tl.constexpr,
):
    """Fused layer-norm forward pass over a single row of X.

    Each program instance handles the row selected by ``tl.program_id(0)``.
    It makes three passes over that row in chunks of ``BLOCK_SIZE`` columns:
    one to accumulate the mean, one to accumulate the (biased, divided by N)
    variance, and one to write ``(x - mean) * rstd * w + b`` to Y.
    The per-row mean and reciprocal standard deviation are stored to
    ``Mean[row]`` and ``Rstd[row]``.

    NOTE(review): presumably launched on a 1D grid with one program per row,
    and with X and Y sharing the same row stride -- confirm against caller.
    """
    # Map the program id to the row of X and Y it should compute.
    # Promote to 64-bit first: program_id is a 32-bit value, and
    # row * stride would wrap for inputs with more than 2**31 elements.
    row = tl.program_id(0).to(tl.int64)
    Y += row * stride
    X += row * stride

    # Compute mean. Accumulate per-lane partial sums in fp32 and reduce once
    # after the loop; masked-off lanes load 0.0 and so do not contribute.
    _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        a = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
        _mean += a
    mean = tl.sum(_mean, axis=0) / N

    # Compute variance as the mean of squared deviations.
    _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
        # Out-of-range lanes loaded 0.0; without the where() they would
        # contribute (0 - mean)**2 to the sum.
        x = tl.where(cols < N, x - mean, 0.0)
        _var += x * x
    var = tl.sum(_var, axis=0) / N
    rstd = 1 / tl.sqrt(var + eps)

    # Write mean / rstd
    tl.store(Mean + row, mean)
    tl.store(Rstd + row, rstd)

    # Normalize and apply linear transformation
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        # w / b are loaded without `other`; the masked-off lanes are never
        # written back because the store below uses the same mask.
        w = tl.load(W + cols, mask=mask)
        b = tl.load(B + cols, mask=mask)
        x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32)
        x_hat = (x - mean) * rstd
        y = x_hat * w + b
        # Write output
        tl.store(Y + cols, y, mask=mask)
0 commit comments