Commit e5ea2f9

test: add some triton tests
1 parent 954f257 commit e5ea2f9

File tree

test/integration/triton/low_memory_dropout.jl
test/integration/triton/low_memory_dropout.py
test/integration/triton/softmax.jl
test/integration/triton/softmax.py
test/integration/triton/vector_add.jl
test/integration/triton/vector_add.py
test/runtests.jl

7 files changed: +236 additions, 0 deletions

test/integration/triton/low_memory_dropout.jl

Lines changed: 38 additions & 0 deletions

@@ -0,0 +1,38 @@
using PythonCall, Reactant, Test

pyimport("sys").path.append(@__DIR__)

low_memory_dropout_kernel = pyimport("low_memory_dropout").seeded_dropout_kernel

const RunningOnCUDA = contains(string(Reactant.devices()[1]), "CUDA")

function seeded_dropout(x::AbstractVector{T}, p::Number, seed) where {T}
    output = similar(x)
    mask = similar(x, Bool)
    low_memory_dropout_kernel(
        x,
        output,
        mask,
        length(x),
        p,
        seed,
        1024;
        grid=(cld(length(x), 1024),),
        blocks=(1024,),
    )
    return output, mask
end

function apply_dropout(x::AbstractVector{T}, mask::AbstractVector, p::Number) where {T}
    return x .* mask ./ (1 - p)
end

@testset "low_memory_dropout" begin
    if RunningOnCUDA
        x_ra = Reactant.to_rarray(rand(Float32, 2056))

        out, mask = @jit seeded_dropout(x_ra, 0.25f0, ConcreteRNumber(123))

        @test @jit(apply_dropout(x_ra, mask, 0.25f0)) ≈ out
    end
end
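
A side note on what "low memory" buys here: the keep-mask is derived inside the kernel from `seed` and the element offsets, so rerunning the same dropout with the same seed reproduces the same mask and the mask never has to be stored between passes. An illustrative check along those lines, assuming the `seeded_dropout` wrapper above and a CUDA-backed device:

# Sketch only: reuses the seeded_dropout wrapper defined above.
x_ra = Reactant.to_rarray(rand(Float32, 2056))
out1, mask1 = @jit seeded_dropout(x_ra, 0.25f0, ConcreteRNumber(123))
out2, mask2 = @jit seeded_dropout(x_ra, 0.25f0, ConcreteRNumber(123))
# Same seed => same pseudo-random keep decisions, hence identical output and mask.
@test out1 ≈ out2
@test Array(mask1) == Array(mask2)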

test/integration/triton/low_memory_dropout.py

Lines changed: 29 additions & 0 deletions

@@ -0,0 +1,29 @@
import triton
import triton.language as tl


@triton.jit
def seeded_dropout_kernel(
    x_ptr,
    output_ptr,
    mask_ptr,
    n_elements,
    p,
    seed,
    BLOCK_SIZE: tl.constexpr,
):
    # compute memory offsets of elements handled by this instance
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # load data from x
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    # randomly prune it
    random = tl.rand(seed, offsets)
    x_keep = random > p
    # write-back
    output = tl.where(x_keep, x / (1 - p), 0.0)
    mask_out = tl.where(x_keep, 1.0, 0.0)
    tl.store(output_ptr + offsets, output, mask=mask)
    tl.store(mask_ptr + offsets, mask_out, mask=mask)
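
A note on the `1 / (1 - p)` factor in the write-back above: it is the usual inverted-dropout rescaling, which keeps the expected value of each activation unchanged. A quick host-side check (Julia, illustrative only):

# Each element is kept with probability (1 - p), so E[keep * x / (1 - p)] = x.
p = 0.25f0
keeps = rand(Float32, 1_000_000) .> p                  # keep with probability 1 - p = 0.75
mean_scaled = sum(keeps ./ (1 - p)) / length(keeps)    # ≈ 1.0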

test/integration/triton/softmax.jl

Lines changed: 60 additions & 0 deletions

@@ -0,0 +1,60 @@
using PythonCall, Reactant, Test

pyimport("sys").path.append(@__DIR__)

softmax_kernel = pyimport("softmax").softmax_kernel

const RunningOnCUDA = contains(string(Reactant.devices()[1]), "CUDA")

function softmax_naive(x::AbstractMatrix{T}) where {T}
    x_max = maximum(x; dims=1)
    z = x .- x_max
    num = exp.(z)
    denom = sum(num; dims=1)
    return num ./ denom
end

function softmax_triton(x::AbstractMatrix{T}) where {T}
    x_transposed = permutedims(x, (2, 1)) # match python array layout
    out = similar(x_transposed)
    n_rows, n_cols = size(x_transposed)

    # NOTE: BLOCK_SIZE and num_stages are not defined anywhere in this diff as shown;
    # the definitions below are assumed. Per softmax.py, BLOCK_SIZE must be the next
    # power of two >= n_cols so that a whole row fits in a single block.
    BLOCK_SIZE = nextpow(2, n_cols)
    num_stages = 4

    function grid_fn(metadata)
        occupancy = (
            metadata.device_properties.regs_per_block ÷
            (metadata.num_regs * metadata.device_properties.warp_size * metadata.num_warps)
        )

        num_programs = min(
            metadata.device_properties.multi_processor_count * min(
                occupancy,
                metadata.device_properties.shared_mem_per_block ÷ metadata.metadata.shared,
            ),
            n_rows,
        )
        return num_programs
    end

    softmax_kernel(
        out,
        x_transposed,
        Reactant.rowmajor_stride(x_transposed, 1),
        Reactant.rowmajor_stride(out, 1),
        n_rows,
        n_cols,
        BLOCK_SIZE,
        num_stages;
        grid=grid_fn,
        blocks=(BLOCK_SIZE,),
    )

    return permutedims(out, (2, 1))
end

@testset "softmax" begin
    if RunningOnCUDA
        x_ra = Reactant.to_rarray(rand(Float32, 132, 2056))

        @test @jit(softmax_triton(x_ra)) ≈ @jit(softmax_naive(x_ra))
    end
end
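
A note on `grid_fn`: it sizes a persistent kernel, launching only as many programs as the device can keep resident at once (bounded by registers and shared memory per SM), capped at `n_rows`; each program then strides over rows inside the kernel. A rough worked example of the same formula with purely hypothetical device numbers (not taken from this commit):

# Illustrative numbers only: 64 KiB of registers per block, a compiled kernel using
# 32 registers per thread, warp size 32, and 8 warps per program.
regs_per_block = 65536
num_regs, warp_size, num_warps = 32, 32, 8
occupancy = regs_per_block ÷ (num_regs * warp_size * num_warps)   # = 8 programs per SM

# With 108 SMs and shared memory not the limiting factor, launch
# min(108 * 8, n_rows) programs; for n_rows = 2056 that is 864 programs,
# each looping over roughly 2 to 3 rows.
multi_processor_count, n_rows = 108, 2056
num_programs = min(multi_processor_count * occupancy, n_rows)     # = 864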

test/integration/triton/softmax.py

Lines changed: 38 additions & 0 deletions

@@ -0,0 +1,38 @@
import triton
import triton.language as tl


@triton.jit
def softmax_kernel(
    output_ptr,
    input_ptr,
    input_row_stride,
    output_row_stride,
    n_rows,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
    num_stages: tl.constexpr,
):
    # starting row of the program
    row_start = tl.program_id(0)
    row_step = tl.num_programs(0)
    for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages):
        # The stride represents how much we need to increase the pointer to advance 1 row
        row_start_ptr = input_ptr + row_idx * input_row_stride
        # The block size is the next power of two greater than n_cols, so we can fit each
        # row in a single block
        col_offsets = tl.arange(0, BLOCK_SIZE)
        input_ptrs = row_start_ptr + col_offsets
        # Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols
        mask = col_offsets < n_cols
        row = tl.load(input_ptrs, mask=mask, other=-float("inf"))
        # Subtract maximum for numerical stability
        row_minus_max = row - tl.max(row, axis=0)
        # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
        numerator = tl.exp(row_minus_max)
        denominator = tl.sum(numerator, axis=0)
        softmax_output = numerator / denominator
        # Write back output to DRAM
        output_row_start_ptr = output_ptr + row_idx * output_row_stride
        output_ptrs = output_row_start_ptr + col_offsets
        tl.store(output_ptrs, softmax_output, mask=mask)
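
One detail worth spelling out: masked-out lanes are loaded as `-inf`, so they never win the row maximum and their exponential is exactly zero, leaving the softmax of the real entries unaffected. The same arithmetic on the host (illustrative Julia):

# A padded "row" where the last lane is masked out with -Inf, as the kernel does.
row = Float32[1.0, 2.0, -Inf]
shifted = row .- maximum(row)       # -Inf stays -Inf and never affects the max
num = exp.(shifted)                 # exp(-Inf) == 0, so the padded lane drops out
softmax_row = num ./ sum(num)       # identical to softmax over just [1.0, 2.0]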

test/integration/triton/vector_add.jl

Lines changed: 22 additions & 0 deletions

@@ -0,0 +1,22 @@
using PythonCall, Reactant, Test

pyimport("sys").path.append(@__DIR__)

add_kernel = pyimport("vector_add").add_kernel

const RunningOnCUDA = contains(string(Reactant.devices()[1]), "CUDA")

function vector_add_triton(x::AbstractVector{T}, y::AbstractVector{T}) where {T}
    out = similar(x)
    add_kernel(x, y, out, length(x), 1024; grid=(cld(length(x), 1024),), blocks=(1024,))
    return out
end

@testset "vector_add" begin
    if RunningOnCUDA
        x_ra = Reactant.to_rarray(rand(Float32, 2096))
        y_ra = Reactant.to_rarray(rand(Float32, 2096))

        @test @jit(vector_add_triton(x_ra, y_ra)) ≈ @jit(x_ra .+ y_ra)
    end
end
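
For the sizes used in this test, the launch arithmetic is worth spelling out (illustrative Julia, not extra test code):

# 2096 elements with BLOCK_SIZE = 1024: cld(2096, 1024) = 3 programs are launched,
# together covering offsets 0:3071; the `offsets < n_elements` mask in the kernel
# discards the trailing 3 * 1024 - 2096 = 976 out-of-bounds lanes.
n_elements, BLOCK_SIZE = 2096, 1024
num_programs = cld(n_elements, BLOCK_SIZE)   # = 3
padded = num_programs * BLOCK_SIZE           # = 3072
masked_off = padded - n_elements             # = 976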

test/integration/triton/vector_add.py

Lines changed: 31 additions & 0 deletions

@@ -0,0 +1,31 @@
import triton
import triton.language as tl


@triton.jit
def add_kernel(
    x_ptr,  # *Pointer* to first input vector.
    y_ptr,  # *Pointer* to second input vector.
    output_ptr,  # *Pointer* to output vector.
    n_elements,  # Size of the vector.
    BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process.
    # NOTE: `constexpr` so it can be used as a shape value.
):
    # There are multiple 'programs' processing different data. We identify which program
    # we are here:
    pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0.
    # This program will process inputs that are offset from the initial data.
    # For instance, if you had a vector of length 256 and block_size of 64, the programs
    # would each access the elements [0:64, 64:128, 128:192, 192:256].
    # Note that offsets is a list of pointers:
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Create a mask to guard memory operations against out-of-bounds accesses.
    mask = offsets < n_elements
    # Load x and y from DRAM, masking out any extra elements in case the input is not a
    # multiple of the block size.
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    output = x + y
    # Write x + y back to DRAM.
    tl.store(output_ptr + offsets, output, mask=mask)
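
The 256-element / block-size-64 case mentioned in the kernel's comment can be checked on the host; this illustrative Julia sketch mirrors the kernel's `block_start + tl.arange(0, BLOCK_SIZE)` offset computation per program id:

# Mirror of the kernel's offset computation for n_elements = 256, BLOCK_SIZE = 64.
BLOCK_SIZE = 64
for pid in 0:3
    block_start = pid * BLOCK_SIZE
    offsets = block_start .+ (0:(BLOCK_SIZE - 1))
    # pid 0 -> 0:63, pid 1 -> 64:127, pid 2 -> 128:191, pid 3 -> 192:255
    println("program ", pid, " handles offsets ", first(offsets), ":", last(offsets))
end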

test/runtests.jl

Lines changed: 18 additions & 0 deletions

@@ -59,6 +59,24 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all"))
         nranks = 2
         run(`$(mpiexec()) -n $nranks $(Base.julia_cmd()) integration/mpi.jl`)
     end
+    @testset "Triton" begin
+        @safetestset "vector_add" include("integration/triton/vector_add.jl")
+        @safetestset "softmax" include("integration/triton/softmax.jl")
+        # @safetestset "matmul" include("integration/triton/matmul.jl")
+        @safetestset "low_memory_dropout" include(
+            "integration/triton/low_memory_dropout.jl"
+        )
+        # @safetestset "layer norm" include("integration/triton/layer_norm.jl")
+        # @safetestset "attention" include("integration/triton/attention.jl")
+        # @safetestset "libdevice" include("integration/triton/libdevice.jl")
+        # @safetestset "grouped gemm" include("integration/triton/grouped_gemm.jl")
+        # @safetestset "persistant matmul" include(
+        #     "integration/triton/persistant_matmul.jl"
+        # )
+        # @safetestset "block scaled matmul" include(
+        #     "integration/triton/block_scaled_matmul.jl"
+        # )
+    end
 end

 if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "neural_networks"
