diff --git a/CondaPkg.toml b/CondaPkg.toml index b1db4f8e75..00aa12cb4a 100644 --- a/CondaPkg.toml +++ b/CondaPkg.toml @@ -5,3 +5,4 @@ python = "<=3.13,>=3.9,<4" jax = ">= 0.6" tensorflow = ">= 2.17" numpy = ">= 2" +triton = ">= 3.4" diff --git a/deps/ReactantExtra/WORKSPACE b/deps/ReactantExtra/WORKSPACE index f936c8fffa..e97efe1436 100644 --- a/deps/ReactantExtra/WORKSPACE +++ b/deps/ReactantExtra/WORKSPACE @@ -4,7 +4,7 @@ NSYNC_COMMIT = "82b118aa7ace3132e517e2c467f8732978cf4023" NSYNC_SHA256 = "" -ENZYMEXLA_COMMIT = "47d57c1cea7b24e210ad75aee6e7c3f93d89ff78" +ENZYMEXLA_COMMIT = "defe9ed6f939cc22a7715f2b8c98a39d9e51e2c9" ENZYMEXLA_SHA256 = "" diff --git a/ext/ReactantPythonCallExt/ReactantPythonCallExt.jl b/ext/ReactantPythonCallExt/ReactantPythonCallExt.jl index 1f10630808..af3852ce2e 100644 --- a/ext/ReactantPythonCallExt/ReactantPythonCallExt.jl +++ b/ext/ReactantPythonCallExt/ReactantPythonCallExt.jl @@ -1,14 +1,20 @@ module ReactantPythonCallExt -using PythonCall: PythonCall, Py, pyconvert, pydict, pyfunc, pyimport, pylist +using PythonCall: + PythonCall, Py, pyconvert, pydict, pyfunc, pyimport, pylist, pyisinstance, pytuple using Reactant: Reactant, TracedRArray, TracedRNumber, @reactant_overlay using Reactant.Ops: @opcall +using Reactant_jll: Reactant_jll const jaxptr = Ref{Py}() const jnpptr = Ref{Py}() const JAX_TRACING_SUPPORTED = Ref{Bool}(false) +const tritonptr = Ref{Py}() + +const TRITON_COMPILE_SUPPORTED = Ref{Bool}(false) + const tfptr = Ref{Py}() const tf2xlaptr = Ref{Py}() const npptr = Ref{Py}() @@ -33,6 +39,28 @@ const NUMPY_SIMPLE_TYPES = Dict( ComplexF64 => :complex64, ) +const MLIR_TYPE_STRING = Dict( + Float64 => "fp64", + Float32 => "fp32", + Float16 => "fp16", + Int64 => "i64", + Int32 => "i32", + Int16 => "i16", + Int8 => "i8", + UInt64 => "ui64", + UInt32 => "ui32", + UInt16 => "ui16", + UInt8 => "ui8", + Bool => "i1", + Reactant.F8E4M3FN => "fp8e4nv", + Reactant.F8E5M2FNUZ => "fp8e5b16", + Reactant.F8E4M3FNUZ => "fp8e4b8", + Reactant.F8E5M2 => "fp8e5", +) +if isdefined(Core, :BFloat16) + MLIR_TYPE_STRING[Core.BFloat16] = "bf16" +end + function __init__() try jaxptr[] = pyimport("jax") @@ -43,6 +71,14 @@ function __init__() be supported." exception = (err, catch_backtrace()) end + try + tritonptr[] = pyimport("triton") + TRITON_COMPILE_SUPPORTED[] = true + catch err + @warn "Failed to import triton. Compiling jax functions with triton won't be \ + supported." exception = (err, catch_backtrace()) + end + try tfptr[] = pyimport("tensorflow") tfptr[].config.set_visible_devices(pylist(); device_type="GPU") diff --git a/ext/ReactantPythonCallExt/overlays.jl b/ext/ReactantPythonCallExt/overlays.jl index 20ffa7384f..ca5bcfcea5 100644 --- a/ext/ReactantPythonCallExt/overlays.jl +++ b/ext/ReactantPythonCallExt/overlays.jl @@ -1,7 +1,7 @@ -@reactant_overlay function PythonCall.pycall(f::Py, args...) +@reactant_overlay function PythonCall.pycall(f::Py, args...; kwargs...) if Reactant.looped_any(Reactant.use_overlayed_version, args) - return pycall_with_jax_tracing(f, args...) + return overlayed_pycall(f, args...; kwargs...) else - return Base.inferencebarrier(PythonCall.pycall)(f, args...) + return Base.inferencebarrier(PythonCall.pycall)(f, args...; kwargs...) 
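+        # inference barrier: avoid specializing through PythonCall when no argument is traced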
end end diff --git a/ext/ReactantPythonCallExt/pycall.jl b/ext/ReactantPythonCallExt/pycall.jl index 23674d9155..51c395dba5 100644 --- a/ext/ReactantPythonCallExt/pycall.jl +++ b/ext/ReactantPythonCallExt/pycall.jl @@ -7,7 +7,18 @@ function Reactant.convert_to_jax_dtype_struct(x::Union{TracedRArray,TracedRNumbe ) end -function pycall_with_jax_tracing(f::Py, args...) +function overlayed_pycall(f::Py, args...; kwargs...) + @assert JAX_TRACING_SUPPORTED[] || TRITON_COMPILE_SUPPORTED[] + # TODO: check for Autotuner and Heutistics as well + if TRITON_COMPILE_SUPPORTED[] && pyisinstance(f, tritonptr[].JITFunction) + return overlayed_pycall_with_triton(f, args...; kwargs...) + else + @assert isempty(kwargs) "`kwargs` are not supported for jax traced functions." + return overlayed_pycall_with_jax_tracing(f, args...) + end +end + +function overlayed_pycall_with_jax_tracing(f::Py, args...) JAX_TRACING_SUPPORTED[] || throw("jax could not be loaded.") seen_args = Reactant.OrderedIdDict() @@ -35,3 +46,144 @@ function pycall_with_jax_tracing(f::Py, args...) res = @opcall hlo_call(pyconvert(String, lowered.as_text()), linear_args...) return length(res) == 0 ? nothing : (length(res) == 1 ? res[1] : res) end + +struct TritonMetadata{CK,MD,DP} + compiled_kernel::CK + metadata::MD + device_properties::DP + num_warps::Int + num_stages::Int + num_ctas::Int + num_regs::Int + num_spills::Int + max_num_threads::Int +end + +canonicalize_grid(grid_fn, metadata) = canonicalize_grid(grid_fn(metadata), metadata) +canonicalize_grid(grid::Integer, metadata) = canonicalize_grid((grid,), metadata) +function canonicalize_grid(grid::Dims{N}, metadata) where {N} + @assert N <= 3 + @assert all(grid .> 0) + return (grid..., ntuple(_ -> 1, 3 - N)...) +end + +signature_string(::TracedRArray{T}) where {T} = "*$(MLIR_TYPE_STRING[T])", nothing +signature_string(::TracedRNumber{T}) where {T} = "$(MLIR_TYPE_STRING[T])", nothing +signature_string(x::T) where {T<:Number} = string(x), x +signature_string(x) = error("Unsupported argument type: $(typeof(x))") + +# TODO: better name for hints? +function overlayed_pycall_with_triton( + kernel::Py, + args...; + grid, + num_warps::Integer=4, + num_stages::Integer=3, + num_ctas::Integer=1, + hints=nothing, +) + @assert num_ctas == 1 "TODO: num_ctas > 1 not supported" + triton = tritonptr[] + + mapped = map(signature_string, args) + signature = first.(mapped) + # TODO: are hints actually correctly set? + hints = + hints === nothing ? 
Dict() : Dict(kernel.arg_names[i - 1] => v for (i, v) in hints) + constants = Dict( + kernel.arg_names[i - 1] => constant for + (i, constant) in enumerate(last.(mapped)) if constant !== nothing + ) + for (k, v) in hints + v == 1 && (constants[kernel.arg_names[k - 1]] = v) + end + attrs = Dict(k => [["tt.divisibility", 16]] for (k, v) in hints if v == 16) + + sigmap = Dict(kernel.arg_names[i - 1] => sig for (i, sig) in enumerate(signature)) + for k in keys(constants) + sigmap[k] = "constexpr" + end + + for h in values(hints) + @assert h in (1, 16) "Only 1 and 16 are valid hints, got $h" + end + attrs = Dict(k => [["tt.divisibility", 16]] for (k, v) in hints if v == 16) + + src = triton.compiler.ASTSource(; + fn=kernel, constexprs=constants, signature=sigmap, attrs=attrs + ) + + # TODO: pass the device/client here from `compile` + # TODO: cluster dims + client = Reactant.XLA.default_backend() + @assert Reactant.XLA.platform_name(client) == "cuda" + device = Reactant.XLA.default_device(client) + device_properties = Reactant.XLA.device_properties(device) + + target = triton.backends.compiler.GPUTarget( + Reactant.XLA.platform_name(client), + parse(Int, "$(device_properties.major)$(device_properties.minor)"), + device_properties.warp_size, + ) + backend = triton.compiler.make_backend(target) + options = backend.parse_options( + pydict( + "num_warps" => num_warps, + "num_stages" => num_stages, + "num_ctas" => num_ctas, + "extern_libs" => pytuple((pytuple(("libdevice", Reactant_jll.libdevice)),)), + ), + ) + + # Currently we are doing a double compilation here. can we do better? + # we are compiling here + lowering again inside enzymejax + compiled_kernel = triton.compile(src; target=target, options=options.__dict__) + + cubin = pyconvert(Vector{UInt8}, compiled_kernel.asm["cubin"]) + fname = pyconvert(String, compiled_kernel.metadata.name) + n_regs, n_spills, n_max_threads = Ref{Int32}(), Ref{Int32}(), Ref{Int32}() + GC.@preserve cubin fname n_regs n_spills n_max_threads begin + @ccall Reactant.MLIR.API.mlir_c.ReactantCudaGetRegsSpillsMaxThreadsFromBinary( + cubin::Ptr{Cvoid}, + fname::Cstring, + n_regs::Ptr{Int32}, + n_spills::Ptr{Int32}, + n_max_threads::Ptr{Int32}, + )::Cvoid + end + + metadata = TritonMetadata( + compiled_kernel, + compiled_kernel.metadata, + device_properties, + num_warps, + num_stages, + num_ctas, + Int(n_regs[]), + Int(n_spills[]), + Int(n_max_threads[]), + ) + + grid = canonicalize_grid(grid, metadata) + + # TODO: actual cluster_x/y/z + + return @opcall triton_call( + pyconvert(String, compiled_kernel.asm["source"]), + filter(x -> x isa Reactant.TracedType, args)...; + func_name=fname, + grid_x=@opcall(constant(grid[1])), + grid_y=@opcall(constant(grid[2])), + grid_z=@opcall(constant(grid[3])), + block_x=@opcall(constant(num_warps * device_properties.warp_size)), + block_y=@opcall(constant(1)), + block_z=@opcall(constant(1)), + cluster_x=@opcall(constant(1)), + cluster_y=@opcall(constant(1)), + cluster_z=@opcall(constant(1)), + num_ctas, + num_warps, + threads_per_warp=device_properties.warp_size, + enable_source_remat=false, + ) +end diff --git a/src/CompileOptions.jl b/src/CompileOptions.jl index e8cac78be6..e545aa6550 100644 --- a/src/CompileOptions.jl +++ b/src/CompileOptions.jl @@ -229,6 +229,8 @@ function CompileOptions(; :canonicalize, :just_batch, :none, + :no_triton, + :before_triton_lowering, ] end diff --git a/src/Compiler.jl b/src/Compiler.jl index bea2f19472..067d0ad2a9 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -704,6 +704,8 @@ function 
optimization_passes( lower_comms::Bool=true, max_constant_threshold::Int=1024, backend::String="gpu", + enable_triton_passes::Bool=false, + device_properties::Union{Nothing,XLA.DeviceProperties}=nothing, ) transform_passes_list = [ "patterns=compare_op_canon<16>", @@ -1312,9 +1314,115 @@ function optimization_passes( push!(passes, "remove-duplicate-func-def") end push!(passes, func_passes) + if enable_triton_passes && backend == "cuda" + push!(passes, triton_optimization_passes(device_properties)) + end return join(passes, ',') end +# https://github.com/triton-lang/triton/blob/8ee584014e9570ba608809c42dc2060fdd214a98/python/src/passes.cc +# To get the latest passes run triton with MLIR_ENABLE_DUMP=1 and then extract the passes +function triton_optimization_passes(device_properties) + @assert device_properties !== nothing "Device properties must be provided to run \ + triton passes. This might happen if you are \ + compiling a triton kernel for non-cuda backend." + major_version = device_properties.major + minor_version = device_properties.minor + + passes_group_1 = join( + [ + "canonicalize", + "triton-rewrite-tensor-pointer", + "canonicalize", + "triton-combine", + "triton-reorder-broadcast", + "cse", + "symbol-dce", + "triton-loop-unroll", + "convert-triton-to-triton-gpu-preserving-module-attributes{target=cuda:$(major_version)$(minor_version)}", + "tritongpu-coalesce", + "tritongpu-F32DotTC", + "triton-nvidia-gpu-plan-cta", + "tritongpu-remove-layout-conversions", + "tritongpu-optimize-thread-locality", + "tritongpu-accelerate-matmul", + "tritongpu-remove-layout-conversions", + "tritongpu-optimize-dot-operands", + "canonicalize", + "triton-nvidia-optimize-descriptor-encoding", + "triton-loop-aware-cse", + "tritongpu-fuse-nested-loops", + "canonicalize", + "triton-licm", + "tritongpu-optimize-accumulator-init", + "tritongpu-hoist-tmem-alloc", + "tritongpu-promote-lhs-to-tmem", + "tritongpu-assign-latencies", + "tritongpu-schedule-loops", + "tritongpu-automatic-warp-specialization", + "tritongpu-partition-scheduling", + "tritongpu-load-mma-specialization", + "tritongpu-rewrite-partition-dependencies", + "sccp", + "cse", + "tritongpu-partition-loops", + "tritongpu-optimize-partition-warps", + "tritongpu-schedule-loops", + "tritongpu-pipeline", + "tritongpu-combine-tensor-select-and-if", + "triton-nvidia-gpu-remove-tmem-tokens", + "canonicalize", + "triton-loop-aware-cse", + "tritongpu-prefetch", + "tritongpu-optimize-dot-operands", + "canonicalize", + "tritongpu-coalesce-async-copy", + "triton-nvidia-optimize-tmem-layouts", + "tritongpu-remove-layout-conversions", + "triton-nvidia-interleave-tmem", + "tritongpu-reduce-data-duplication", + "tritongpu-reorder-instructions", + "triton-loop-aware-cse", + "symbol-dce", + "triton-nvidia-tma-lowering", + "triton-nvidia-gpu-fence-insertion", + "sccp", + "canonicalize", + "triton-nvidia-mma-lowering", + "tritongpu-combine-tensor-select-and-if", + "tritongpu-allocate-warp-groups", + "convert-scf-to-cf", + "allocate-shared-memory", + "triton-tensor-memory-allocation", + "tritongpu-global-scratch-memory-allocation", + "convert-triton-gpu-to-llvm", + ], + ",", + ) + passes_group_2 = join( + [ + "canonicalize", + "cse", + "convert-nv-gpu-to-llvm", + "convert-warp-specialize-to-llvm", + "reconcile-unrealized-casts", + "canonicalize", + "cse", + "symbol-dce", + "enable-line-info", + ], + ",", + ) + return join( + [ + "enzymexla_tt_ext.module(builtin.module($(passes_group_1)))", + "triton-augment-function-with-extra-arguments", + 
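+            # passes_group_2 performs the final LLVM lowering once the extra kernel arguments have been appended by the pass above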
"enzymexla_tt_ext.module(builtin.module($(passes_group_2)))", + ], + ",", + ) +end + # TODO we want to be able to run the more advanced passes via transform dialect as an enzyme intermediate # However, this errs as we cannot attach the transform with to the funcop itself [as we run a functionpass]. const enzyme_pass::String = "enzyme{postpasses=\"arith-raise{stablehlo=true},canonicalize,cse,canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math,canonicalize,cse,canonicalize,arith-raise{stablehlo=true}\"}" @@ -1660,6 +1768,9 @@ function compile_mlir!( toolkit = XLA.CUDA_DATA_DIR[] + default_device = XLA.default_device(client) + device_properties = XLA.device_properties(default_device) + if backend == "cpu" || backend == "tpu" kern = "lower-kernel{backend=cpu},canonicalize" if backend == "tpu" @@ -1674,9 +1785,7 @@ function compile_mlir!( "lower-kernel,canonicalize" end - device_properties = XLA.device_properties(XLA.default_device(client)) cubinChip = "sm_$(device_properties.major)$(device_properties.minor)" - if DEBUG_KERNEL[] curesulthandler = dlsym( Reactant_jll.libReactantExtra_handle, "ReactantHandleCuResult" @@ -1701,10 +1810,31 @@ function compile_mlir!( end opt_passes = optimization_passes( - compile_options; sroa=true, recognize_comms, lower_comms, backend + compile_options; + sroa=true, + recognize_comms, + lower_comms, + backend, + enable_triton_passes=false, + device_properties, ) opt_passes2 = optimization_passes( - compile_options; sroa=false, recognize_comms, lower_comms, backend + compile_options; + sroa=false, + recognize_comms, + lower_comms, + backend, + enable_triton_passes=false, + device_properties, + ) + opt_passes_with_triton = optimization_passes( + compile_options; + sroa=false, + recognize_comms, + lower_comms, + backend, + enable_triton_passes=true, + device_properties, ) raise_passes = if raise isa String @@ -1719,15 +1849,16 @@ function compile_mlir!( opt_passes2 if DUS_TO_CONCAT[] - opt_passes3 = optimization_passes( + opt_passes_dus_to_concat = optimization_passes( compile_options; sroa=false, dus_to_concat=true, recognize_comms, lower_comms, backend, + device_properties, ) - result = result * "," * opt_passes3 + result = result * "," * opt_passes_dus_to_concat end result else @@ -1755,6 +1886,8 @@ function compile_mlir!( [ "mark-func-memory-effects", opt_passes, + opt_passes_with_triton, + "lower-triton", kern, raise_passes, "enzyme-batch", @@ -1767,6 +1900,7 @@ function compile_mlir!( legalize_chlo_to_stablehlo..., opt_passes2, lower_enzymexla_linalg_pass, + "lower-triton-extension-ops", jit, ] else @@ -1776,15 +1910,17 @@ function compile_mlir!( "enzyme-batch", opt_passes2, enzyme_pass, - opt_passes2, + opt_passes_with_triton, "canonicalize", "remove-unnecessary-enzyme-ops", "enzyme-simplify-math", legalize_chlo_to_stablehlo..., opt_passes2, + "lower-triton", kern, raise_passes, lower_enzymexla_linalg_pass, + "lower-triton-extension-ops", jit, ] end, @@ -1792,7 +1928,32 @@ function compile_mlir!( ), "all", ) - elseif compile_options.optimization_passes === :before_kernel + elseif compile_options.optimization_passes === :no_triton + run_pass_pipeline!( + mod, + join( + if compile_options.raise_first + ["mark-func-memory-effects", opt_passes] + else + [ + "mark-func-memory-effects", + opt_passes, + "enzyme-batch", + opt_passes2, + enzyme_pass, + opt_passes2, + "canonicalize", + "remove-unnecessary-enzyme-ops", + "enzyme-simplify-math", + legalize_chlo_to_stablehlo..., + opt_passes2, + ] + end, + ',', + ), + "no_triton", + ) + elseif 
compile_options.optimization_passes === :before_triton_lowering run_pass_pipeline!( mod, join( @@ -1805,12 +1966,38 @@ function compile_mlir!( "enzyme-batch", opt_passes2, enzyme_pass, + opt_passes_with_triton, + "canonicalize", + "remove-unnecessary-enzyme-ops", + "enzyme-simplify-math", + legalize_chlo_to_stablehlo..., opt_passes2, + ] + end, + ',', + ), + "before_triton_lowering", + ) + elseif compile_options.optimization_passes === :before_kernel + run_pass_pipeline!( + mod, + join( + if compile_options.raise_first + ["mark-func-memory-effects", opt_passes] + else + [ + "mark-func-memory-effects", + opt_passes, + "enzyme-batch", + opt_passes2, + enzyme_pass, + opt_passes_with_triton, "canonicalize", "remove-unnecessary-enzyme-ops", "enzyme-simplify-math", legalize_chlo_to_stablehlo..., opt_passes2, + "lower-triton", ] end, ',', @@ -1824,7 +2011,8 @@ function compile_mlir!( if compile_options.raise_first [ "mark-func-memory-effects", - opt_passes, + opt_passes_with_triton, + "lower-triton", kern, raise_passes, "enzyme-batch", @@ -1844,12 +2032,13 @@ function compile_mlir!( "enzyme-batch", opt_passes2, enzyme_pass, - opt_passes2, + opt_passes_with_triton, "canonicalize", "remove-unnecessary-enzyme-ops", "enzyme-simplify-math", legalize_chlo_to_stablehlo..., opt_passes2, + "lower-triton", kern, raise_passes, ] @@ -1871,12 +2060,13 @@ function compile_mlir!( "enzyme-batch", opt_passes2, enzyme_pass, - opt_passes2, + opt_passes_with_triton, "canonicalize", "remove-unnecessary-enzyme-ops", "enzyme-simplify-math", legalize_chlo_to_stablehlo..., opt_passes2, + "lower-triton", kern, ] end, @@ -1894,7 +2084,7 @@ function compile_mlir!( "enzyme-batch", opt_passes2, enzyme_pass, - opt_passes2, + opt_passes_with_triton, "canonicalize", "remove-unnecessary-enzyme-ops", "enzyme-simplify-math", @@ -1936,8 +2126,9 @@ function compile_mlir!( "remove-unnecessary-enzyme-ops", "enzyme-simplify-math", legalize_chlo_to_stablehlo..., - opt_passes2, + opt_passes_with_triton, lower_enzymexla_linalg_pass, + "lower-triton-extension-ops", jit, ] else @@ -1949,10 +2140,12 @@ function compile_mlir!( "remove-unnecessary-enzyme-ops", "enzyme-simplify-math", legalize_chlo_to_stablehlo..., - opt_passes2, + opt_passes_with_triton, + "lower-triton", kern, raise_passes, lower_enzymexla_linalg_pass, + "lower-triton-extension-ops", jit, ] end, @@ -1967,7 +2160,8 @@ function compile_mlir!( if compile_options.raise_first [ "mark-func-memory-effects", - opt_passes, + opt_passes_with_triton, + "lower-triton", kern, raise_passes, "enzyme-batch", @@ -1975,6 +2169,7 @@ function compile_mlir!( enzyme_pass, "canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math", lower_enzymexla_linalg_pass, + "lower-triton-extension-ops", jit, ] else @@ -1982,12 +2177,14 @@ function compile_mlir!( "mark-func-memory-effects", opt_passes, "enzyme-batch", - opt_passes2, + opt_passes_with_triton, enzyme_pass, "canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math", + "lower-triton", kern, raise_passes, lower_enzymexla_linalg_pass, + "lower-triton-extension-ops", jit, ] end, @@ -2018,6 +2215,7 @@ function compile_mlir!( recognize_comms, lower_comms, backend, + device_properties, ), "post_op_transpose_reshape", ) diff --git a/src/Ops.jl b/src/Ops.jl index 22bf67679b..f0b8d29332 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -3,7 +3,7 @@ # Julia and Reactant semantics should be considered on the higher abstractions that use these module Ops using ..MLIR: MLIR -using ..MLIR.Dialects: stablehlo, chlo, enzyme, enzymexla +using 
..MLIR.Dialects: stablehlo, chlo, enzyme, enzymexla, enzymexla_tt_ext using ..Reactant: Reactant, TracedRArray, @@ -1811,8 +1811,181 @@ end end # Generate a unique name given a module hash and a function name. -function _hlo_call_name(orig_name, module_suffix) - return orig_name * "_hlo_call_" * module_suffix +_new_function_name(orig_name, module_suffix) = orig_name * "_call_" * module_suffix + +function _extract_function( + code::String; + func_name::String="main", + func_op_kind::String="func.func", + location::MLIR.IR.Location=MLIR.IR.Location(), +) + module_suffix = string(hash(code); base=16) + name_to_call = func_name * "_call_" * module_suffix + mod_name = func_name * "_module_" * module_suffix + symbol_attr_name = String(MLIR.API.mlirSymbolTableGetSymbolAttributeName()) + + use_ttext_module = split(func_op_kind, ".")[1] == "tt" + + if use_ttext_module + tt_mod_name = func_name * "_tt_module_" * module_suffix + tt_region = MLIR.IR.Region() + tt_block = MLIR.IR.Block() + push!(tt_region, tt_block) + triton_mod_op = enzymexla_tt_ext.module_(; + location, bodyRegion=tt_region, sym_name=tt_mod_name + ) + MLIR.IR.rmfromparent!(triton_mod_op) + push!(MLIR.IR.body(MLIR.IR.mmodule()), triton_mod_op) # insert into parent module + + region = MLIR.IR.Region() + push!(region, MLIR.IR.Block()) + moduleop = MLIR.Dialects.builtin.module_(; + location, bodyRegion=region, sym_name=mod_name + ) + MLIR.IR.rmfromparent!(moduleop) + push!(tt_block, moduleop) # insert into triton module + + top_level_block = MLIR.IR.Block( + MLIR.API.mlirModuleGetBody(MLIR.API.mlirModuleFromOperation(moduleop)), false + ) + fn = nothing + + symref = MLIR.IR.SymbolRefAttribute( + tt_mod_name, + MLIR.IR.Attribute[ + MLIR.IR.FlatSymbolRefAttribute(mod_name), + MLIR.IR.FlatSymbolRefAttribute(name_to_call), + ], + ) + else + current_module = MLIR.IR.mmodule() + moduleop = MLIR.IR.Operation(current_module) + top_level_block = MLIR.IR.body(current_module) + fn = MLIR.IR.lookup(MLIR.IR.SymbolTable(moduleop), name_to_call) + symref = MLIR.IR.FlatSymbolRefAttribute(name_to_call) + end + + if isnothing(fn) + new_mod = parse(MLIR.IR.Module, code) + new_mod_op = MLIR.IR.Operation(new_mod) + body = MLIR.IR.body(new_mod) + + operations = collect(MLIR.IR.OperationIterator(body)) + idx = Base.findfirst(op -> MLIR.IR.name(op) == func_op_kind, operations) + @assert idx !== nothing + op = operations[idx] + + fn_name = String(MLIR.IR.attr(op, symbol_attr_name)) + fn_name == func_name && (fn = op) + + res = MLIR.IR.LogicalResult( + MLIR.API.mlirSymbolTableReplaceAllSymbolUses(fn_name, name_to_call, new_mod_op) + ) + @assert res == MLIR.IR.success() "hlo_call: failed to rename $fn_name" + + if !use_ttext_module + # Set function private + MLIR.IR.attr!( + op, + MLIR.API.mlirSymbolTableGetVisibilityAttributeName(), + MLIR.IR.Attribute("private"), + ) + end + + # Change function name + MLIR.IR.attr!(op, symbol_attr_name, MLIR.IR.Attribute(name_to_call)) + + for op in operations + MLIR.IR.rmfromparent!(op) + push!(top_level_block, op) + end + end + + if isnothing(fn) + error("hlo_call: could not find function $func_name in the provided module") + end + + return fn, symref, moduleop +end + +function triton_call( + mlir_code::String, + args::Union{TracedRArray,TracedRNumber,Number}...; + func_name::String="main", + grid_x::TracedRNumber{<:Integer}, + grid_y::TracedRNumber{<:Integer}, + grid_z::TracedRNumber{<:Integer}, + block_x::TracedRNumber{<:Integer}, + block_y::TracedRNumber{<:Integer}, + block_z::TracedRNumber{<:Integer}, + 
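+    # cluster dimensions are plumbed through for future CTA-cluster support; callers currently pass constant 1s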
cluster_x::TracedRNumber{<:Integer}, + cluster_y::TracedRNumber{<:Integer}, + cluster_z::TracedRNumber{<:Integer}, + num_ctas::Integer=1, + num_warps::Integer=4, + threads_per_warp::Integer=32, + enable_source_remat::Bool=false, + location=mlir_stacktrace("triton_call", @__FILE__, @__LINE__), +) + _, symref, modop = _extract_function( + mlir_code; func_name, func_op_kind="tt.func", location + ) + + MLIR.IR.attr!(modop, "enzymexla.ttg.num-warps", MLIR.IR.Attribute(Int32(num_warps))) + MLIR.IR.attr!(modop, "enzymexla.ttg.num-ctas", MLIR.IR.Attribute(Int32(num_ctas))) + MLIR.IR.attr!( + modop, "enzymexla.ttg.threads-per-warp", MLIR.IR.Attribute(Int32(threads_per_warp)) + ) + if enable_source_remat + MLIR.IR.attr!(modop, "enzymexla.ttg.enable-source-remat", MLIR.IR.UnitAttribute()) + end + + result_types = MLIR.IR.Type[] + output_operand_aliases = MLIR.IR.Attribute[] + output_to_arg = Int[] + for (i, arg) in enumerate(args) + if arg isa TracedRArray + push!(result_types, mlir_type(typeof(arg), size(arg))) + push!( + output_operand_aliases, + MLIR.IR.Attribute( + MLIR.API.stablehloOutputOperandAliasGet( + MLIR.IR.context(), 1, Int64[i - 1], Int64(i - 1), 0, C_NULL + ), + ), + ) + push!(output_to_arg, i) + end + end + + results = enzymexla_tt_ext.call( + grid_x.mlir_data, + grid_y.mlir_data, + grid_z.mlir_data, + block_x.mlir_data, + block_y.mlir_data, + block_z.mlir_data, + cluster_x.mlir_data, + cluster_y.mlir_data, + cluster_z.mlir_data, + [Reactant.TracedUtils.get_mlir_data(a) for a in args]; + fn=symref, + result_0=result_types, + location, + output_operand_aliases, + ) + + array_results = () + for i in 1:MLIR.IR.nresults(results) + arg = args[output_to_arg[i]] + res = Reactant.TracedRArray{unwrapped_eltype(arg),ndims(arg)}( + (), MLIR.IR.result(results, i), size(arg) + ) + copyto!(arg, res) + array_results = (array_results..., res) + end + length(array_results) == 1 && return array_results[1] + return array_results end """ @@ -1841,69 +2014,16 @@ julia> Reactant.@jit( """ @noinline function hlo_call( code, - args...; + args::Union{TracedRArray,TracedRNumber}...; func_name="main", location=mlir_stacktrace("hlo_call", @__FILE__, @__LINE__), ) - module_suffix = string(hash(code); base=16) - name_to_call = _hlo_call_name(func_name, module_suffix) - - current_module = MLIR.IR.mmodule() - top_level_block = MLIR.IR.body(current_module) - - symbol_attr_name = String(MLIR.API.mlirSymbolTableGetSymbolAttributeName()) - - fn = MLIR.IR.lookup( - MLIR.IR.SymbolTable(MLIR.IR.Operation(current_module)), name_to_call - ) - if isnothing(fn) - new_mod = parse(MLIR.IR.Module, code) - new_mod_op = MLIR.IR.Operation(new_mod) - body = MLIR.IR.body(new_mod) - - operations = collect(MLIR.IR.OperationIterator(body)) - for op in operations - if MLIR.IR.name(op) == "func.func" - fn_name = String(MLIR.IR.attr(op, symbol_attr_name)) - if fn_name == func_name - fn = op - end - - new_name = _hlo_call_name(fn_name, module_suffix) - res = MLIR.IR.LogicalResult( - MLIR.API.mlirSymbolTableReplaceAllSymbolUses( - fn_name, new_name, new_mod_op - ), - ) - @assert res == MLIR.IR.success() "hlo_call: failed to rename $fn_name" - - # Set function private - MLIR.IR.attr!( - op, - MLIR.API.mlirSymbolTableGetVisibilityAttributeName(), - MLIR.IR.Attribute("private"), - ) - - # Change function name - MLIR.IR.attr!(op, symbol_attr_name, MLIR.IR.Attribute(new_name)) - end - end - - for op in operations - MLIR.IR.rmfromparent!(op) - push!(top_level_block, op) - end - end - - if isnothing(fn) - error("hlo_call: could not find function 
$func_name in the provided module") - end + fn, symref, _ = _extract_function(code; func_name, func_op_kind="func.func", location) ftype_attr = MLIR.IR.attr(fn, "function_type") ftype = MLIR.IR.Type(ftype_attr) - @assert all(Base.Fix2(isa, Union{TracedRArray,TracedRNumber}), args) "hlo_call: all inputs to hlo_call should be reactant arrays or numbers" - @assert MLIR.IR.ninputs(ftype) == length(args) "hlo_call: invalid number of arguments for function $func_name" + @assert MLIR.IR.ninputs(ftype) == length(args) "hlo_call: invalid number of arguments for function $func_name. Expected $(MLIR.IR.ninputs(ftype)), got $(length(args))" for (i, arg) in enumerate(args) expected_type = MLIR.IR.input(ftype, i) @@ -1915,7 +2035,7 @@ julia> Reactant.@jit( call = MLIR.Dialects.func.call( operands; result_0=[MLIR.IR.result(ftype, i) for i in 1:MLIR.IR.nresults(ftype)], - callee=MLIR.IR.FlatSymbolRefAttribute(name_to_call), + callee=symref, location, ) diff --git a/src/Reactant.jl b/src/Reactant.jl index 69cee2b3b0..dd6add1114 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -236,6 +236,35 @@ include("stdlibs/Base.jl") # Other Integrations include("Enzyme.jl") +""" + rowmajor_strides(x::AbstractArray) + +Returns the strides of the array `x` assuming that the array is stored in row-major order. +""" +rowmajor_strides(x::AbstractArray) = rowmajor_strides(size(x)) +function rowmajor_strides(sz::NTuple{N,Int}) where {N} + strides = ntuple(_ -> 1, N) + for i in (N - 1):-1:1 + strides = Base.setindex(strides, strides[i + 1] * sz[i + 1], i) + end + return strides +end + +""" + rowmajor_stride(x::AbstractArray, i::Integer) + +Returns the stride of the array `x` at dimension `i` assuming that the array is stored in +row-major order. +""" +rowmajor_stride(x::AbstractArray, i::Integer) = rowmajor_stride(size(x), i) +function rowmajor_stride(sz::NTuple{N,Int}, i::Integer) where {N} + s = 1 + for j in (i + 1):N + s *= sz[j] + end + return s +end + export StackedBatchDuplicated, StackedBatchDuplicatedNoNeed const TracedType = Union{TracedRArray,TracedRNumber,MissingTracedValue} diff --git a/src/mlir/IR/Module.jl b/src/mlir/IR/Module.jl index 12794b30ba..c7f938d5b8 100644 --- a/src/mlir/IR/Module.jl +++ b/src/mlir/IR/Module.jl @@ -52,7 +52,8 @@ body(module_) = Block(API.mlirModuleGetBody(module_), false) Views the module as a generic operation. 
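+Pass `owned=true` for the returned `Operation` to take ownership of the underlying MLIR operation (defaults to `false`).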
""" -Operation(module_::Module) = Operation(API.mlirModuleGetOperation(module_), false) +Operation(module_::Module, owned::Bool=false) = + Operation(API.mlirModuleGetOperation(module_), owned) function Base.show(io::IO, module_::Module) return show(io, Operation(module_)) diff --git a/test/integration/triton/layer_norm.jl b/test/integration/triton/layer_norm.jl new file mode 100644 index 0000000000..f9652da235 --- /dev/null +++ b/test/integration/triton/layer_norm.jl @@ -0,0 +1,71 @@ +using PythonCall, Reactant, Test + +pyimport("sys").path.append(@__DIR__) + +layer_norm_kernel = pyimport("layer_norm").layer_norm_fwd_fused +layer_norm_kernel_v2 = pyimport("layer_norm").layer_norm_fwd_fused_simple + +const RunningOnCUDA = contains(string(Reactant.devices()[1]), "CUDA") + +function layer_norm_triton( + x::AbstractMatrix{T}, weight::AbstractVector{T}, bias::AbstractVector{T}, simple::Bool +) where {T} + x_transposed = permutedims(x, (2, 1)) # match python array layout + y = similar(x_transposed) + M, N = size(x_transposed) + mean = similar(x_transposed, Float32, M) + rstd = similar(x_transposed, Float32, M) + + max_fused_size = 65536 ÷ sizeof(T) + block_size = min(max_fused_size, nextpow(2, N)) + + if N > block_size + throw(ArgumentError("This layer norm doesn't support feature dim >= 64KB.")) + end + + (simple ? layer_norm_kernel_v2 : layer_norm_kernel)( + x_transposed, + y, + weight, + bias, + mean, + rstd, + Reactant.rowmajor_stride(x_transposed, 1), + N, + 1.0f-5, + block_size; + num_warps=min(max(block_size ÷ 256, 1), 8), + num_ctas=1, + grid=(M,), + ) + + return permutedims(y, (2, 1)), mean, rstd +end + +function layer_norm_naive( + x::AbstractMatrix{T}, weight::AbstractVector{T}, bias::AbstractVector{T} +) where {T} + mean = sum(x; dims=1) ./ size(x, 1) + rstd = 1 ./ sqrt.(sum(abs2, x .- mean; dims=1) ./ size(x, 1) .+ 1e-5) + x_hat = (x .- mean) .* rstd + return x_hat .* weight .+ bias, vec(mean), vec(rstd) +end + +@testset "fused_layer_norm" begin + if RunningOnCUDA + x_ra = Reactant.to_rarray(rand(Float32, 257, 2056)) + weight_ra = Reactant.to_rarray(rand(Float32, 257)) + bias_ra = Reactant.to_rarray(rand(Float32, 257)) + + y_ra1, mean_ra1, rstd_ra1 = @jit layer_norm_triton(x_ra, weight_ra, bias_ra, false) + y_ra2, mean_ra2, rstd_ra2 = @jit layer_norm_naive(x_ra, weight_ra, bias_ra) + y_ra3, mean_ra3, rstd_ra3 = @jit layer_norm_triton(x_ra, weight_ra, bias_ra, true) + + @test y_ra1 ≈ y_ra2 + @test y_ra2 ≈ y_ra3 + @test mean_ra1 ≈ mean_ra2 + @test mean_ra2 ≈ mean_ra3 + @test rstd_ra1 ≈ rstd_ra2 + @test rstd_ra2 ≈ rstd_ra3 + end +end diff --git a/test/integration/triton/layer_norm.py b/test/integration/triton/layer_norm.py new file mode 100644 index 0000000000..9595491551 --- /dev/null +++ b/test/integration/triton/layer_norm.py @@ -0,0 +1,103 @@ +import triton +import triton.language as tl + + +@triton.jit +def layer_norm_fwd_fused( + X, # pointer to the input + Y, # pointer to the output + W, # pointer to the weights + B, # pointer to the biases + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + stride, # how much to increase the pointer when moving by 1 row + N, # number of columns in X + eps, # epsilon to avoid division by zero + BLOCK_SIZE: tl.constexpr, +): + # Map the program id to the row of X and Y it should compute. 
+ row = tl.program_id(0) + Y += row * stride + X += row * stride + # Compute mean + mean = 0 + _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + a = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) + _mean += a + mean = tl.sum(_mean, axis=0) / N + # Compute variance + _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) + x = tl.where(cols < N, x - mean, 0.0) + _var += x * x + var = tl.sum(_var, axis=0) / N + rstd = 1 / tl.sqrt(var + eps) + # Write mean / rstd + tl.store(Mean + row, mean) + tl.store(Rstd + row, rstd) + # Normalize and apply linear transformation + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + mask = cols < N + w = tl.load(W + cols, mask=mask) + b = tl.load(B + cols, mask=mask) + x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32) + x_hat = (x - mean) * rstd + y = x_hat * w + b + # Write output + tl.store(Y + cols, y, mask=mask) + + +@triton.jit +def layer_norm_fwd_fused_simple( + X, # pointer to the input + Y, # pointer to the output + W, # pointer to the weights + B, # pointer to the biases + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + stride, # how much to increase the pointer when moving by 1 row + N, # number of columns in X + eps, # epsilon to avoid division by zero + BLOCK_SIZE: tl.constexpr, +): + # Map the program id to the row of X and Y it should compute. + row = tl.program_id(0) + Y += row * stride + X += row * stride + + # Compute mean - process one element at a time + mean = 0.0 + for i in range(N): + x = tl.load(X + i).to(tl.float32) + mean += x + mean = mean / N + + # Compute variance - process one element at a time + var = 0.0 + for i in range(N): + x = tl.load(X + i).to(tl.float32) + diff = x - mean + var += diff * diff + var = var / N + rstd = 1.0 / tl.sqrt(var + eps) + + # Write mean / rstd + tl.store(Mean + row, mean) + tl.store(Rstd + row, rstd) + + # Normalize and apply linear transformation + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + mask = cols < N + w = tl.load(W + cols, mask=mask) + b = tl.load(B + cols, mask=mask) + x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32) + x_hat = (x - mean) * rstd + y = x_hat * w + b + # Write output + tl.store(Y + cols, y, mask=mask) diff --git a/test/integration/triton/libdevice.jl b/test/integration/triton/libdevice.jl new file mode 100644 index 0000000000..89eee78e99 --- /dev/null +++ b/test/integration/triton/libdevice.jl @@ -0,0 +1,21 @@ +using PythonCall, Reactant, Test + +pyimport("sys").path.append(@__DIR__) + +asin_kernel = pyimport("libdevice").asin_kernel + +const RunningOnCUDA = contains(string(Reactant.devices()[1]), "CUDA") + +function asin_triton(x::AbstractVector{T}) where {T} + out = similar(x) + asin_kernel(x, out, length(x), 1024; grid=(cld(length(x), 1024),)) + return out +end + +@testset "libdevice asin" begin + if RunningOnCUDA + x_ra = Reactant.to_rarray(rand(Float32, 2096)) + + @test @jit(asin_triton(x_ra)) ≈ @jit(asin.(x_ra)) + end +end diff --git a/test/integration/triton/libdevice.py b/test/integration/triton/libdevice.py new file mode 100644 index 0000000000..ac9a199952 --- /dev/null +++ b/test/integration/triton/libdevice.py @@ -0,0 +1,19 @@ +import triton +import triton.language as tl +from triton.language.extra import libdevice + + +@triton.jit +def asin_kernel( + 
x_ptr, + y_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + x = libdevice.asin(x) + tl.store(y_ptr + offsets, x, mask=mask) diff --git a/test/integration/triton/low_memory_dropout.jl b/test/integration/triton/low_memory_dropout.jl new file mode 100644 index 0000000000..48be41490b --- /dev/null +++ b/test/integration/triton/low_memory_dropout.jl @@ -0,0 +1,30 @@ +using PythonCall, Reactant, Test + +pyimport("sys").path.append(@__DIR__) + +low_memory_dropout_kernel = pyimport("low_memory_dropout").seeded_dropout_kernel + +const RunningOnCUDA = contains(string(Reactant.devices()[1]), "CUDA") + +function seeded_dropout(x::AbstractVector{T}, p::Number, seed) where {T} + output = similar(x) + mask = similar(x, Bool) + low_memory_dropout_kernel( + x, output, mask, length(x), p, seed, 1024; grid=(cld(length(x), 1024),) + ) + return output, mask +end + +function apply_dropout(x::AbstractVector{T}, mask::AbstractVector, p::Number) where {T} + return x .* mask ./ (1 - p) +end + +@testset "low_memory_dropout" begin + if RunningOnCUDA + x_ra = Reactant.to_rarray(rand(Float32, 2056)) + + out, mask = @jit seeded_dropout(x_ra, 0.25f0, ConcreteRNumber(123)) + + @test @jit(apply_dropout(x_ra, mask, 0.25f0)) ≈ out + end +end diff --git a/test/integration/triton/low_memory_dropout.py b/test/integration/triton/low_memory_dropout.py new file mode 100644 index 0000000000..ad32ac0014 --- /dev/null +++ b/test/integration/triton/low_memory_dropout.py @@ -0,0 +1,29 @@ +import triton +import triton.language as tl + + +@triton.jit +def seeded_dropout_kernel( + x_ptr, + output_ptr, + mask_ptr, + n_elements, + p, + seed, + BLOCK_SIZE: tl.constexpr, +): + # compute memory offsets of elements handled by this instance + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + # load data from x + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + # randomly prune it + random = tl.rand(seed, offsets) + x_keep = random > p + # write-back + output = tl.where(x_keep, x / (1 - p), 0.0) + mask_out = tl.where(x_keep, 1.0, 0.0) + tl.store(output_ptr + offsets, output, mask=mask) + tl.store(mask_ptr + offsets, mask_out, mask=mask) diff --git a/test/integration/triton/matmul.jl b/test/integration/triton/matmul.jl new file mode 100644 index 0000000000..ea841dd771 --- /dev/null +++ b/test/integration/triton/matmul.jl @@ -0,0 +1,61 @@ +using PythonCall, Reactant, Test + +pyimport("sys").path.append(@__DIR__) + +matmul_kernel = pyimport("matmul").matmul_kernel + +const RunningOnCUDA = contains(string(Reactant.devices()[1]), "CUDA") + +function matmul_triton(a::AbstractMatrix{T}, b::AbstractMatrix{T}) where {T} + # a: [M, K] --> aᵀ: [K, M] + # b: [K, N] --> bᵀ: [N, K] + # c: a × b [M, N] --> cᵀ: bᵀ × aᵀ [N, M] + a_transposed = permutedims(a, (2, 1)) # match python array layout + b_transposed = permutedims(b, (2, 1)) # match python array layout + @assert size(b_transposed, 2) == size(a_transposed, 1) "Inner dimensions must match \ + for matmul" + M, K = size(b_transposed) + K, N = size(a_transposed) + + out = similar(a_transposed, T, M, N) # cᵀ + + matmul_kernel( + b_transposed, + a_transposed, + out, + M, + N, + K, + Reactant.rowmajor_stride(b_transposed, 1), + Reactant.rowmajor_stride(b_transposed, 2), + Reactant.rowmajor_stride(a_transposed, 1), + 
Reactant.rowmajor_stride(a_transposed, 2), + Reactant.rowmajor_stride(out, 1), + Reactant.rowmajor_stride(out, 2), + 64, + 256, + 32, + 8; + grid=(cld(M, 64) * cld(N, 256),), + num_stages=4, + num_warps=4, + ) + + return permutedims(out, (2, 1)) +end + +@testset "matmul" begin + if RunningOnCUDA + @testset for M in (4, 32, 256, 1024), + K in (4, 32, 512, 2048), + N in (4, 32, 256, 1024) + + a = Reactant.to_rarray(rand(Float32, M, K)) + b = Reactant.to_rarray(rand(Float32, K, N)) + + # XXX: shared_memory???? + # XXX: seems to work correctly for small matrices + @test_broken @jit(matmul_triton(a, b)) ≈ @jit(a * b) + end + end +end diff --git a/test/integration/triton/matmul.py b/test/integration/triton/matmul.py new file mode 100644 index 0000000000..f4dafc0318 --- /dev/null +++ b/test/integration/triton/matmul.py @@ -0,0 +1,264 @@ +import triton +import triton.language as tl + + +# XXX: enable and support autotuning +# @triton.autotune( +# configs=[ +# triton.Config( +# { +# "BLOCK_SIZE_M": 128, +# "BLOCK_SIZE_N": 256, +# "BLOCK_SIZE_K": 64, +# "GROUP_SIZE_M": 8, +# }, +# num_stages=3, +# num_warps=8, +# ), +# triton.Config( +# { +# "BLOCK_SIZE_M": 64, +# "BLOCK_SIZE_N": 256, +# "BLOCK_SIZE_K": 32, +# "GROUP_SIZE_M": 8, +# }, +# num_stages=4, +# num_warps=4, +# ), +# triton.Config( +# { +# "BLOCK_SIZE_M": 128, +# "BLOCK_SIZE_N": 128, +# "BLOCK_SIZE_K": 32, +# "GROUP_SIZE_M": 8, +# }, +# num_stages=4, +# num_warps=4, +# ), +# triton.Config( +# { +# "BLOCK_SIZE_M": 128, +# "BLOCK_SIZE_N": 64, +# "BLOCK_SIZE_K": 32, +# "GROUP_SIZE_M": 8, +# }, +# num_stages=4, +# num_warps=4, +# ), +# triton.Config( +# { +# "BLOCK_SIZE_M": 64, +# "BLOCK_SIZE_N": 128, +# "BLOCK_SIZE_K": 32, +# "GROUP_SIZE_M": 8, +# }, +# num_stages=4, +# num_warps=4, +# ), +# triton.Config( +# { +# "BLOCK_SIZE_M": 128, +# "BLOCK_SIZE_N": 32, +# "BLOCK_SIZE_K": 32, +# "GROUP_SIZE_M": 8, +# }, +# num_stages=4, +# num_warps=4, +# ), +# triton.Config( +# { +# "BLOCK_SIZE_M": 64, +# "BLOCK_SIZE_N": 32, +# "BLOCK_SIZE_K": 32, +# "GROUP_SIZE_M": 8, +# }, +# num_stages=5, +# num_warps=2, +# ), +# triton.Config( +# { +# "BLOCK_SIZE_M": 32, +# "BLOCK_SIZE_N": 64, +# "BLOCK_SIZE_K": 32, +# "GROUP_SIZE_M": 8, +# }, +# num_stages=5, +# num_warps=2, +# ), +# # Good config for fp8 inputs. 
+# triton.Config( +# { +# "BLOCK_SIZE_M": 128, +# "BLOCK_SIZE_N": 256, +# "BLOCK_SIZE_K": 128, +# "GROUP_SIZE_M": 8, +# }, +# num_stages=3, +# num_warps=8, +# ), +# triton.Config( +# { +# "BLOCK_SIZE_M": 256, +# "BLOCK_SIZE_N": 128, +# "BLOCK_SIZE_K": 128, +# "GROUP_SIZE_M": 8, +# }, +# num_stages=3, +# num_warps=8, +# ), +# triton.Config( +# { +# "BLOCK_SIZE_M": 256, +# "BLOCK_SIZE_N": 64, +# "BLOCK_SIZE_K": 128, +# "GROUP_SIZE_M": 8, +# }, +# num_stages=4, +# num_warps=4, +# ), +# triton.Config( +# { +# "BLOCK_SIZE_M": 64, +# "BLOCK_SIZE_N": 256, +# "BLOCK_SIZE_K": 128, +# "GROUP_SIZE_M": 8, +# }, +# num_stages=4, +# num_warps=4, +# ), +# triton.Config( +# { +# "BLOCK_SIZE_M": 128, +# "BLOCK_SIZE_N": 128, +# "BLOCK_SIZE_K": 128, +# "GROUP_SIZE_M": 8, +# }, +# num_stages=4, +# num_warps=4, +# ), +# triton.Config( +# { +# "BLOCK_SIZE_M": 128, +# "BLOCK_SIZE_N": 64, +# "BLOCK_SIZE_K": 64, +# "GROUP_SIZE_M": 8, +# }, +# num_stages=4, +# num_warps=4, +# ), +# triton.Config( +# { +# "BLOCK_SIZE_M": 64, +# "BLOCK_SIZE_N": 128, +# "BLOCK_SIZE_K": 64, +# "GROUP_SIZE_M": 8, +# }, +# num_stages=4, +# num_warps=4, +# ), +# triton.Config( +# { +# "BLOCK_SIZE_M": 128, +# "BLOCK_SIZE_N": 32, +# "BLOCK_SIZE_K": 64, +# "GROUP_SIZE_M": 8, +# }, +# num_stages=4, +# num_warps=4, +# ), +# ], +# key=["M", "N", "K"], +# ) +@triton.jit +def matmul_kernel( + # Pointers to matrices + a_ptr, + b_ptr, + c_ptr, + # Matrix dimensions + M, + N, + K, + # The stride variables represent how much to increase the ptr by when moving by 1 + # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr` + # by to get the element one row down (A has M rows). + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + """Kernel for computing the matmul C = A x B. + A has shape (M, K), B has shape (K, N) and C has shape (M, N) + """ + # ----------------------------------------------------------- + # Map program ids `pid` to the block of C it should compute. + # This is done in a grouped ordering to promote L2 data reuse. + # See above `L2 Cache Optimizations` section for details. + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # ----------------------------------------------------------- + # Add some integer bound assumptions. + # This helps to guide integer analysis in the backend to optimize + # load/store offset address calculation + tl.assume(pid_m >= 0) + tl.assume(pid_n >= 0) + tl.assume(stride_am > 0) + tl.assume(stride_ak > 0) + tl.assume(stride_bn > 0) + tl.assume(stride_bk > 0) + tl.assume(stride_cm > 0) + tl.assume(stride_cn > 0) + + # ---------------------------------------------------------- + # Create pointers for the first blocks of A and B. 
+ # We will advance this pointer as we move in the K direction + # and accumulate + # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers + # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers + # See above `Pointer Arithmetic` section for details + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + # ----------------------------------------------------------- + # Iterate to compute a block of the C matrix. + # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block + # of fp32 values for higher accuracy. + # `accumulator` will be converted back to fp16 after the loop. + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + # Load the next block of A and B, generate a mask by checking the K dimension. + # If it is out of bounds, set it to 0. + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + # We accumulate along the K dimension. + accumulator = tl.dot(a, b, accumulator) + # Advance the ptrs to the next K block. + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + c = accumulator.to(tl.float16) + + # ----------------------------------------------------------- + # Write back the block of the output matrix C with masks. + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) diff --git a/test/integration/triton/softmax.jl b/test/integration/triton/softmax.jl new file mode 100644 index 0000000000..815f390754 --- /dev/null +++ b/test/integration/triton/softmax.jl @@ -0,0 +1,61 @@ +using PythonCall, Reactant, Test + +pyimport("sys").path.append(@__DIR__) + +softmax_kernel = pyimport("softmax").softmax_kernel + +const RunningOnCUDA = contains(string(Reactant.devices()[1]), "CUDA") + +function softmax_naive(x::AbstractMatrix{T}) where {T} + x_max = maximum(x; dims=1) + z = x .- x_max + num = exp.(z) + denom = sum(num; dims=1) + return num ./ denom +end + +function softmax_triton(x::AbstractMatrix{T}) where {T} + x_transposed = permutedims(x, (2, 1)) # match python array layout + out = similar(x_transposed) + n_rows, n_cols = size(x_transposed) + + BLOCK_SIZE = nextpow(2, n_cols) + + function grid_fn(metadata) + occupancy = ( + metadata.device_properties.regs_per_block ÷ + (metadata.num_regs * metadata.device_properties.warp_size * metadata.num_warps) + ) + + num_programs = min( + metadata.device_properties.multi_processor_count * min( + occupancy, + metadata.device_properties.shared_mem_per_block ÷ metadata.metadata.shared, + ), + n_rows, + ) + return num_programs + end + + softmax_kernel( + out, + x_transposed, + Reactant.rowmajor_stride(x_transposed, 1), + Reactant.rowmajor_stride(out, 1), + n_rows, + n_cols, + BLOCK_SIZE, + num_stages=3; + grid=grid_fn, + ) + + return permutedims(out, (2, 1)) +end + +@testset "softmax" begin + if RunningOnCUDA + x_ra = Reactant.to_rarray(rand(Float32, 132, 2056)) + + @test @jit(softmax_triton(x_ra)) ≈ @jit(softmax_naive(x_ra)) + 
end +end diff --git a/test/integration/triton/softmax.py b/test/integration/triton/softmax.py new file mode 100644 index 0000000000..0a80c43275 --- /dev/null +++ b/test/integration/triton/softmax.py @@ -0,0 +1,38 @@ +import triton +import triton.language as tl + + +@triton.jit +def softmax_kernel( + output_ptr, + input_ptr, + input_row_stride, + output_row_stride, + n_rows, + n_cols, + BLOCK_SIZE: tl.constexpr, + num_stages: tl.constexpr, +): + # starting row of the program + row_start = tl.program_id(0) + row_step = tl.num_programs(0) + for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages): + # The stride represents how much we need to increase the pointer to advance 1 row + row_start_ptr = input_ptr + row_idx * input_row_stride + # The block size is the next power of two greater than n_cols, so we can fit each + # row in a single block + col_offsets = tl.arange(0, BLOCK_SIZE) + input_ptrs = row_start_ptr + col_offsets + # Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols + mask = col_offsets < n_cols + row = tl.load(input_ptrs, mask=mask, other=-float("inf")) + # Subtract maximum for numerical stability + row_minus_max = row - tl.max(row, axis=0) + # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA) + numerator = tl.exp(row_minus_max) + denominator = tl.sum(numerator, axis=0) + softmax_output = numerator / denominator + # Write back output to DRAM + output_row_start_ptr = output_ptr + row_idx * output_row_stride + output_ptrs = output_row_start_ptr + col_offsets + tl.store(output_ptrs, softmax_output, mask=mask) diff --git a/test/integration/triton/vector_add.jl b/test/integration/triton/vector_add.jl new file mode 100644 index 0000000000..5a96e3b785 --- /dev/null +++ b/test/integration/triton/vector_add.jl @@ -0,0 +1,22 @@ +using PythonCall, Reactant, Test + +pyimport("sys").path.append(@__DIR__) + +add_kernel = pyimport("vector_add").add_kernel + +const RunningOnCUDA = contains(string(Reactant.devices()[1]), "CUDA") + +function vector_add_triton(x::AbstractVector{T}, y::AbstractVector{T}) where {T} + out = similar(x) + add_kernel(x, y, out, length(x), 1024; grid=(cld(length(x), 1024),)) + return out +end + +@testset "vector_add" begin + if RunningOnCUDA + x_ra = Reactant.to_rarray(rand(Float32, 2096)) + y_ra = Reactant.to_rarray(rand(Float32, 2096)) + + @test @jit(vector_add_triton(x_ra, y_ra)) ≈ @jit(x_ra .+ y_ra) + end +end diff --git a/test/integration/triton/vector_add.py b/test/integration/triton/vector_add.py new file mode 100644 index 0000000000..6b04d51a7d --- /dev/null +++ b/test/integration/triton/vector_add.py @@ -0,0 +1,31 @@ +import triton +import triton.language as tl + + +@triton.jit +def add_kernel( + x_ptr, # *Pointer* to first input vector. + y_ptr, # *Pointer* to second input vector. + output_ptr, # *Pointer* to output vector. + n_elements, # Size of the vector. + BLOCK_SIZE: tl.constexpr, # Number of elements each program should process. + # NOTE: `constexpr` so it can be used as a shape value. +): + # There are multiple 'programs' processing different data. We identify which program + # we are here: + pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0. + # This program will process inputs that are offset from the initial data. + # For instance, if you had a vector of length 256 and block_size of 64, the programs + # would each access the elements [0:64, 64:128, 128:192, 192:256]. 
+ # Note that offsets is a list of pointers: + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + # Create a mask to guard memory operations against out-of-bounds accesses. + mask = offsets < n_elements + # Load x and y from DRAM, masking out any extra elements in case the input is not a + # multiple of the block size. + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.load(y_ptr + offsets, mask=mask) + output = x + y + # Write x + y back to DRAM. + tl.store(output_ptr + offsets, output, mask=mask) diff --git a/test/runtests.jl b/test/runtests.jl index bc10705e43..cb60df0e1c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -63,6 +63,25 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all")) if VERSION < v"1.12-" @safetestset "Zygote" include("integration/zygote.jl") end + + @testset "Triton" begin + @safetestset "vector_add" include("integration/triton/vector_add.jl") + @safetestset "softmax" include("integration/triton/softmax.jl") + # @safetestset "matmul" include("integration/triton/matmul.jl") # XXX + @safetestset "low_memory_dropout" include( + "integration/triton/low_memory_dropout.jl" + ) + @safetestset "layer norm" include("integration/triton/layer_norm.jl") + # @safetestset "attention" include("integration/triton/attention.jl") + @safetestset "libdevice" include("integration/triton/libdevice.jl") + # @safetestset "grouped gemm" include("integration/triton/grouped_gemm.jl") + # @safetestset "persistant matmul" include( + # "integration/triton/persistant_matmul.jl" + # ) + # @safetestset "block scaled matmul" include( + # "integration/triton/block_scaled_matmul.jl" + # ) + end end if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "neural_networks"