Commit 357fa84

committed
test: layer_norm + libdevice
1 parent e5ea2f9 commit 357fa84

File tree

5 files changed: +161 -2 lines changed

test/integration/triton/layer_norm.jl

Lines changed: 67 additions & 0 deletions
@@ -0,0 +1,67 @@
using PythonCall, Reactant, Test

pyimport("sys").path.append(@__DIR__)

layer_norm_kernel = pyimport("layer_norm").layer_norm_fwd_fused

function layer_norm_triton(
    x::AbstractMatrix{T}, weight::AbstractVector{T}, bias::AbstractVector{T}
) where {T}
    x_transposed = permutedims(x, (2, 1)) # match python array layout
    y = similar(x_transposed)
    M, N = size(x_transposed)
    mean = similar(x_transposed, Float32, M)
    rstd = similar(x_transposed, Float32, M)

    max_fused_size = 65536 ÷ sizeof(T)
    block_size = min(max_fused_size, nextpow(2, N))

    if N > block_size
        throw(ArgumentError("This layer norm doesn't support feature dim >= 64KB."))
    end

    num_warps = min(max(block_size ÷ 256, 1), 8)

    layer_norm_kernel(
        x_transposed,
        y,
        weight,
        bias,
        mean,
        rstd,
        Reactant.rowmajor_stride(x_transposed, 1),
        N,
        1.0f-5,
        block_size;
        num_warps=num_warps,
        num_ctas=1,
        grid=(M,),
        blocks=(block_size,),
    )

    return permutedims(y, (2, 1)), mean, rstd
end

function layer_norm_naive(
    x::AbstractMatrix{T}, weight::AbstractVector{T}, bias::AbstractVector{T}
) where {T}
    mean = sum(x; dims=1) ./ size(x, 1)
    rstd = 1 ./ sqrt.(sum(abs2, x .- mean; dims=1) ./ size(x, 1) .+ 1e-5)
    x_hat = (x .- mean) .* rstd
    return x_hat .* weight .+ bias, vec(mean), vec(rstd)
end

@testset "fused_layer_norm" begin
    if RunningOnCUDA
        x_ra = Reactant.to_rarray(rand(Float32, 256, 2056))
        weight_ra = Reactant.to_rarray(rand(Float32, 256))
        bias_ra = Reactant.to_rarray(rand(Float32, 256))

        y_ra1, mean_ra1, rstd_ra1 = @jit layer_norm_triton(x_ra, weight_ra, bias_ra)
        y_ra2, mean_ra2, rstd_ra2 = @jit layer_norm_naive(x_ra, weight_ra, bias_ra)

        @test y_ra1 ≈ y_ra2
        @test mean_ra1 ≈ mean_ra2
        @test rstd_ra1 ≈ rstd_ra2
    end
end
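The wrapper transposes x before the launch to match the Python/row-major layout the kernel expects (per the in-code comment), so after the permutedims M is the number of samples and N the feature count. For the 256×2056 test input used in the testset, the launch parameters work out as follows (a worked sketch, not code from this commit):

x = rand(Float32, 256, 2056)           # 256 features × 2056 samples, as in the testset
M, N = size(permutedims(x, (2, 1)))    # (2056, 256): one kernel row per sample
block_size = min(65536 ÷ sizeof(Float32), nextpow(2, N))   # min(16384, 256) = 256
num_warps = min(max(block_size ÷ 256, 1), 8)               # max(1, 1) = 1
# => grid = (M,) = (2056,), blocks = (block_size,) = (256,):
#    one program instance per sample, covering all 256 features in a single block.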
test/integration/triton/layer_norm.py

Lines changed: 52 additions & 0 deletions
@@ -0,0 +1,52 @@
import triton
import triton.language as tl


@triton.jit
def layer_norm_fwd_fused(
    X,  # pointer to the input
    Y,  # pointer to the output
    W,  # pointer to the weights
    B,  # pointer to the biases
    Mean,  # pointer to the mean
    Rstd,  # pointer to the 1/std
    stride,  # how much to increase the pointer when moving by 1 row
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    BLOCK_SIZE: tl.constexpr,
):
    # Map the program id to the row of X and Y it should compute.
    row = tl.program_id(0)
    Y += row * stride
    X += row * stride
    # Compute mean
    mean = 0
    _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        a = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
        _mean += a
    mean = tl.sum(_mean, axis=0) / N
    # Compute variance
    _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
        x = tl.where(cols < N, x - mean, 0.0)
        _var += x * x
    var = tl.sum(_var, axis=0) / N
    rstd = 1 / tl.sqrt(var + eps)
    # Write mean / rstd
    tl.store(Mean + row, mean)
    tl.store(Rstd + row, rstd)
    # Normalize and apply linear transformation
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        w = tl.load(W + cols, mask=mask)
        b = tl.load(B + cols, mask=mask)
        x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32)
        x_hat = (x - mean) * rstd
        y = x_hat * w + b
        # Write output
        tl.store(Y + cols, y, mask=mask)
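The kernel closely follows the upstream Triton layer-norm forward tutorial: each program instance owns one row and makes three blocked passes over it, accumulating the mean, then the variance, then writing the normalized and affine-transformed output. A plain-Julia mirror of what a single instance computes for one row, included only to make the per-row math explicit (the helper below is illustrative and not part of this commit):

# Plain-Julia sketch of one kernel instance's work on a single row.
# B plays the role of BLOCK_SIZE; Triton's `mask` becomes a clamped slice.
function fused_row_layer_norm(x_row::Vector{Float32}, w, b; eps=1f-5, B=256)
    N = length(x_row)
    μ = 0f0                                   # pass 1: blocked mean accumulation
    for off in 1:B:N
        μ += sum(@view x_row[off:min(off + B - 1, N)])
    end
    μ /= N
    σ² = 0f0                                  # pass 2: blocked variance accumulation
    for off in 1:B:N
        σ² += sum(abs2, (@view x_row[off:min(off + B - 1, N)]) .- μ)
    end
    σ² /= N
    rstd = 1 / sqrt(σ² + eps)
    return (x_row .- μ) .* rstd .* w .+ b, μ, rstd   # pass 3 (vectorized here): normalize + affine
end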
test/integration/triton/libdevice.jl

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
using PythonCall, Reactant, Test

pyimport("sys").path.append(@__DIR__)

asin_kernel = pyimport("libdevice").asin_kernel

const RunningOnCUDA = contains(string(Reactant.devices()[1]), "CUDA")

function asin_triton(x::AbstractVector{T}) where {T}
    out = similar(x)
    asin_kernel(x, out, length(x), 1024; grid=(cld(length(x), 1024),), blocks=(1024,))
    return out
end

@testset "libdevice asin" begin
    if RunningOnCUDA
        x_ra = Reactant.to_rarray(rand(Float32, 2096))

        @test @jit(asin_triton(x_ra)) ≈ @jit(asin.(x_ra))
    end
end
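asin_triton launches the kernel with a fixed block width of 1024, so the grid size is cld(length(x), 1024). For the 2096-element test vector that works out as below (illustrative arithmetic, not part of this commit):

n = 2096
cld(n, 1024)     # = 3 program instances in the grid
n - 2 * 1024     # = 48: the last instance sees only 48 in-bounds elements;
                 #   the kernel's `mask = offsets < n_elements` guards the tail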
test/integration/triton/libdevice.py

Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
import triton
import triton.language as tl
from triton.language.extra import libdevice


@triton.jit
def asin_kernel(
    x_ptr,
    y_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    x = libdevice.asin(x)
    tl.store(y_ptr + offsets, x, mask=mask)

test/runtests.jl

Lines changed: 2 additions & 2 deletions
@@ -66,9 +66,9 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all"))
     @safetestset "low_memory_dropout" include(
         "integration/triton/low_memory_dropout.jl"
     )
-    # @safetestset "layer norm" include("integration/triton/layer_norm.jl")
+    @safetestset "layer norm" include("integration/triton/layer_norm.jl")
     # @safetestset "attention" include("integration/triton/attention.jl")
-    # @safetestset "libdevice" include("integration/triton/libdevice.jl")
+    @safetestset "libdevice" include("integration/triton/libdevice.jl")
     # @safetestset "grouped gemm" include("integration/triton/grouped_gemm.jl")
     # @safetestset "persistant matmul" include(
     #     "integration/triton/persistant_matmul.jl"

0 commit comments
