
Conversation

avik-pal (Collaborator) commented Sep 25, 2025

import triton
import triton.language as tl


@triton.jit
def add_kernel(
    x_ptr,  # *Pointer* to first input vector.
    y_ptr,  # *Pointer* to second input vector.
    output_ptr,  # *Pointer* to output vector.
    n_elements,  # Size of the vector.
    BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process.
    # NOTE: `constexpr` so it can be used as a shape value.
):
    # There are multiple 'programs' processing different data. We identify which program
    # we are here:
    pid = tl.program_id(axis=0)  # We use a 1D launch grid so axis is 0.
    # This program will process inputs that are offset from the initial data.
    # For instance, if you had a vector of length 256 and block_size of 64, the programs
    # would each access the elements [0:64, 64:128, 128:192, 192:256].
    # Note that `offsets` is a block of index values; adding it to a pointer below yields a block of pointers.
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Create a mask to guard memory operations against out-of-bounds accesses.
    mask = offsets < n_elements
    # Load x and y from DRAM, masking out any extra elements in case the input is not a
    # multiple of the block size.
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    output = x + y
    # Write x + y back to DRAM.
    tl.store(output_ptr + offsets, output, mask=mask)
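
The kernel above (the standard Triton vector-add tutorial kernel) is saved as vector_add.py; the Julia script below appends @__DIR__ to Python's sys.path and imports it through PythonCall: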
using PythonCall, Reactant

pyimport("sys").path.append(@__DIR__)
kernel = pyimport("vector_add").add_kernel

x = Reactant.to_rarray(rand(Float32, 1024));
y = Reactant.to_rarray(rand(Float32, 1024));
out = Reactant.to_rarray(zeros(Float32, 1024));

@code_hlo kernel(
    x,
    y,
    out,
    length(x),
    64;
    grid=cld(length(x), 64),
    num_warps=1,
    num_stages=3,
    hints=Dict(1 => 16),
)
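
For completeness, a minimal sketch of actually compiling and running the call rather than only inspecting the HLO. This is not taken from the PR: it assumes Reactant's @jit accepts the same positional and keyword arguments as the @code_hlo call above, and that the donated/aliased buffers update `out` in place.

# Hypothetical usage sketch: compile and execute the same call with @jit.
@jit kernel(
    x,
    y,
    out,
    length(x),
    64;
    grid=cld(length(x), 64),
    num_warps=1,
    num_stages=3,
    hints=Dict(1 => 16),
)

# The inputs are aliased to the outputs in the generated HLO, so `out` should now hold x .+ y.
@assert Array(out) ≈ Array(x) .+ Array(y)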

avik-pal force-pushed the ap/triton_integration branch 2 times, most recently from a7ece19 to f776758 on September 27, 2025 13:17

avik-pal commented Sep 27, 2025
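
Generated @code_hlo output for the example above: the Triton kernel is embedded as a nested tt module, and main launches it through enzymexla.triton_call with a (16, 1, 1) grid and zero bytes of shared memory.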

module @reactant_JITFunc... attributes {mhlo.num_partitions = 1 : i64, mhlo.num_replicas = 1 : i64} {
  module @tt_module_0 {
    tt.func @add_kernel_call_e72661bb113efd0f(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>, %arg2: !tt.ptr<f32>) attributes {noinline = false} {
      %0 = tt.get_program_id x : i32
      %c64_i32 = arith.constant 64 : i32
      %c64_i32_0 = arith.constant 64 : i32
      %1 = arith.extsi %0 : i32 to i64
      %2 = arith.extsi %c64_i32_0 : i32 to i64
      %3 = arith.muli %1, %2 : i64
      %c2147483647_i64 = arith.constant 2147483647 : i64
      %c-2147483648_i64 = arith.constant -2147483648 : i64
      %4 = arith.cmpi sle, %3, %c2147483647_i64 : i64
      %5 = arith.cmpi sge, %3, %c-2147483648_i64 : i64
      %6 = arith.andi %4, %5 : i1
      %7 = arith.muli %0, %c64_i32_0 : i32
      %8 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
      %9 = tt.splat %7 : i32 -> tensor<64xi32>
      %10 = arith.extsi %9 : tensor<64xi32> to tensor<64xi64>
      %11 = arith.extsi %8 : tensor<64xi32> to tensor<64xi64>
      %12 = arith.addi %10, %11 : tensor<64xi64>
      %c2147483647_i64_1 = arith.constant 2147483647 : i64
      %c-2147483648_i64_2 = arith.constant -2147483648 : i64
      %cst = arith.constant dense<2147483647> : tensor<64xi64>
      %13 = arith.cmpi sle, %12, %cst : tensor<64xi64>
      %cst_3 = arith.constant dense<-2147483648> : tensor<64xi64>
      %14 = arith.cmpi sge, %12, %cst_3 : tensor<64xi64>
      %15 = arith.andi %13, %14 : tensor<64xi1>
      %16 = arith.addi %9, %8 : tensor<64xi32>
      %c1024_i32 = arith.constant 1024 : i32
      %cst_4 = arith.constant dense<1024> : tensor<64xi32>
      %17 = arith.cmpi slt, %16, %cst_4 : tensor<64xi32>
      %18 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<64x!tt.ptr<f32>>
      %19 = tt.addptr %18, %16 : tensor<64x!tt.ptr<f32>>, tensor<64xi32>
      %20 = tt.load %19, %17 : tensor<64x!tt.ptr<f32>>
      %21 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<64x!tt.ptr<f32>>
      %22 = tt.addptr %21, %16 : tensor<64x!tt.ptr<f32>>, tensor<64xi32>
      %23 = tt.load %22, %17 : tensor<64x!tt.ptr<f32>>
      %24 = arith.addf %20, %23 : tensor<64xf32>
      %25 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<64x!tt.ptr<f32>>
      %26 = tt.addptr %25, %16 : tensor<64x!tt.ptr<f32>>, tensor<64xi32>
      tt.store %26, %24, %17 : tensor<64x!tt.ptr<f32>>
      tt.return
    }
  }
  func.func @main(%arg0: tensor<1024xf32> {tf.aliasing_output = 0 : i32}, %arg1: tensor<1024xf32> {tf.aliasing_output = 1 : i32}, %arg2: tensor<1024xf32> {tf.aliasing_output = 2 : i32}) -> (tensor<1024xf32>, tensor<1024xf32>, tensor<1024xf32>) {
    %0 = stablehlo.transpose %arg0, dims = [0] : (tensor<1024xf32>) -> tensor<1024xf32>
    %1 = stablehlo.transpose %arg1, dims = [0] : (tensor<1024xf32>) -> tensor<1024xf32>
    %2 = stablehlo.transpose %arg2, dims = [0] : (tensor<1024xf32>) -> tensor<1024xf32>
    %c = stablehlo.constant dense<16> : tensor<i64>
    %c_0 = stablehlo.constant dense<1> : tensor<i64>
    %c_1 = stablehlo.constant dense<1> : tensor<i64>
    %c_2 = stablehlo.constant dense<0> : tensor<i64>
    enzymexla.triton_call @tt_module_0::@add_kernel_call_e72661bb113efd0f blocks in(%c, %c_0, %c_1) shmem = %c_2 (%0, %1, %2) : (tensor<1024xf32>, tensor<1024xf32>, tensor<1024xf32>) -> ()
    %3 = stablehlo.transpose %0, dims = [0] : (tensor<1024xf32>) -> tensor<1024xf32>
    %4 = stablehlo.transpose %1, dims = [0] : (tensor<1024xf32>) -> tensor<1024xf32>
    %5 = stablehlo.transpose %2, dims = [0] : (tensor<1024xf32>) -> tensor<1024xf32>
    return %3, %4, %5 : tensor<1024xf32>, tensor<1024xf32>, tensor<1024xf32>
  }
}

avik-pal force-pushed the ap/triton_integration branch 2 times, most recently from 95598f9 to 7f0afd8 on September 28, 2025 16:20
avik-pal (Collaborator, Author) commented:
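
Updated output with the new triton_ext dialect: the kernel now sits inside a triton_ext.module / builtin.module pair and is invoked from main via triton_ext.call with a nested symbol reference instead of enzymexla.triton_call.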

module @reactant_JITFunc... attributes {mhlo.num_partitions = 1 : i64, mhlo.num_replicas = 1 : i64} {
  triton_ext.module @add_kernel_tt_module_e72661bb113efd0f {
    builtin.module @add_kernel_module_e72661bb113efd0f {
      tt.func private @add_kernel_call_e72661bb113efd0f(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>, %arg2: !tt.ptr<f32>) attributes {enzymexla.memory_effects = ["read", "write"], noinline = false} {
        %0 = tt.get_program_id x : i32
        %c64_i32 = arith.constant 64 : i32
        %c64_i32_0 = arith.constant 64 : i32
        %1 = arith.extsi %0 : i32 to i64
        %2 = arith.extsi %c64_i32_0 : i32 to i64
        %3 = arith.muli %1, %2 : i64
        %c2147483647_i64 = arith.constant 2147483647 : i64
        %c-2147483648_i64 = arith.constant -2147483648 : i64
        %4 = arith.cmpi sle, %3, %c2147483647_i64 : i64
        %5 = arith.cmpi sge, %3, %c-2147483648_i64 : i64
        %6 = arith.andi %4, %5 : i1
        %7 = arith.muli %0, %c64_i32_0 : i32
        %8 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
        %9 = tt.splat %7 : i32 -> tensor<64xi32>
        %10 = arith.extsi %9 : tensor<64xi32> to tensor<64xi64>
        %11 = arith.extsi %8 : tensor<64xi32> to tensor<64xi64>
        %12 = arith.addi %10, %11 : tensor<64xi64>
        %c2147483647_i64_1 = arith.constant 2147483647 : i64
        %c-2147483648_i64_2 = arith.constant -2147483648 : i64
        %cst = arith.constant dense<2147483647> : tensor<64xi64>
        %13 = arith.cmpi sle, %12, %cst : tensor<64xi64>
        %cst_3 = arith.constant dense<-2147483648> : tensor<64xi64>
        %14 = arith.cmpi sge, %12, %cst_3 : tensor<64xi64>
        %15 = arith.andi %13, %14 : tensor<64xi1>
        %16 = arith.addi %9, %8 : tensor<64xi32>
        %c1024_i32 = arith.constant 1024 : i32
        %cst_4 = arith.constant dense<1024> : tensor<64xi32>
        %17 = arith.cmpi slt, %16, %cst_4 : tensor<64xi32>
        %18 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<64x!tt.ptr<f32>>
        %19 = tt.addptr %18, %16 : tensor<64x!tt.ptr<f32>>, tensor<64xi32>
        %20 = tt.load %19, %17 : tensor<64x!tt.ptr<f32>>
        %21 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<64x!tt.ptr<f32>>
        %22 = tt.addptr %21, %16 : tensor<64x!tt.ptr<f32>>, tensor<64xi32>
        %23 = tt.load %22, %17 : tensor<64x!tt.ptr<f32>>
        %24 = arith.addf %20, %23 : tensor<64xf32>
        %25 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<64x!tt.ptr<f32>>
        %26 = tt.addptr %25, %16 : tensor<64x!tt.ptr<f32>>, tensor<64xi32>
        tt.store %26, %24, %17 : tensor<64x!tt.ptr<f32>>
        tt.return
      }
    }
  }
  func.func @main(%arg0: tensor<1024xf32> {tf.aliasing_output = 0 : i32}, %arg1: tensor<1024xf32> {tf.aliasing_output = 1 : i32}, %arg2: tensor<1024xf32> {tf.aliasing_output = 2 : i32}) -> (tensor<1024xf32>, tensor<1024xf32>, tensor<1024xf32>) attributes {enzymexla.memory_effects = ["read", "write", "allocate", "free"]} {
    %0 = stablehlo.transpose %arg0, dims = [0] : (tensor<1024xf32>) -> tensor<1024xf32>
    %1 = stablehlo.transpose %arg1, dims = [0] : (tensor<1024xf32>) -> tensor<1024xf32>
    %2 = stablehlo.transpose %arg2, dims = [0] : (tensor<1024xf32>) -> tensor<1024xf32>
    %c = stablehlo.constant dense<16> : tensor<i64>
    %c_0 = stablehlo.constant dense<1> : tensor<i64>
    %c_1 = stablehlo.constant dense<1> : tensor<i64>
    %c_2 = stablehlo.constant dense<0> : tensor<i64>
    triton_ext.call @add_kernel_tt_module_e72661bb113efd0f::@add_kernel_module_e72661bb113efd0f::@add_kernel_call_e72661bb113efd0f blocks in(%c, %c_0, %c_1) shmem = %c_2 (%0, %1, %2) : (tensor<1024xf32>, tensor<1024xf32>, tensor<1024xf32>) -> ()
    %3 = stablehlo.transpose %0, dims = [0] : (tensor<1024xf32>) -> tensor<1024xf32>
    %4 = stablehlo.transpose %1, dims = [0] : (tensor<1024xf32>) -> tensor<1024xf32>
    %5 = stablehlo.transpose %2, dims = [0] : (tensor<1024xf32>) -> tensor<1024xf32>
    return %3, %4, %5 : tensor<1024xf32>, tensor<1024xf32>, tensor<1024xf32>
  }
}

avik-pal force-pushed the ap/triton_integration branch from 7f0afd8 to 4876110 on September 29, 2025 20:02
avik-pal (Collaborator, Author) commented:
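
Latest output with the Triton compilation pipeline run ahead of time: the kernel is lowered all the way to the LLVM/NVVM dialect (predicated ld.global/st.global through inline PTX), and main launches it via triton_ext.call with a (16, 1, 1) grid, zero shared memory, and no leftover StableHLO transposes.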

module @reactant_JITFunc... attributes {mhlo.num_partitions = 1 : i64, mhlo.num_replicas = 1 : i64} {
  triton_ext.module @add_kernel_tt_module_e72661bb113efd0f {
    builtin.module @add_kernel_module_e72661bb113efd0f attributes {ttg.global_scratch_memory_alignment = 1 : i32, ttg.global_scratch_memory_size = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.shared = 0 : i32, ttg.target = "cuda:120", ttg.tensor_memory_size = 0 : i32, "ttg.threads-per-warp" = 32 : i32, "ttg.total-num-warps" = 1 : i32} {
      llvm.mlir.global external @global_smem() {addr_space = 3 : i32, alignment = 16 : i64} : !llvm.array<0 x i8>
      llvm.func @add_kernel_call_e72661bb113efd0f(%arg0: !llvm.ptr<1>, %arg1: !llvm.ptr<1>, %arg2: !llvm.ptr<1>, %arg3: !llvm.ptr<1>) attributes {enzymexla.memory_effects = ["read", "write", "allocate", "free"], noinline = false, nvvm.kernel = 1 : ui1, nvvm.reqntid = array<i32: 32>, ttg.global_scratch_memory_alignment = 1 : i32, ttg.global_scratch_memory_size = 0 : i32} {
        %0 = llvm.mlir.undef : vector<1xf32>
        %1 = llvm.mlir.constant(0 : i32) : i32
        %2 = llvm.mlir.constant(32 : i32) : i32
        %3 = llvm.mlir.constant(31 : i32) : i32
        %4 = llvm.mlir.constant(0 : index) : i32
        %5 = llvm.mlir.constant(1024 : i32) : i32
        %6 = llvm.mlir.constant(64 : i32) : i32
        %7 = llvm.call_intrinsic "llvm.nvvm.read.ptx.sreg.ctaid.x"() : () -> i32
        %8 = llvm.mul %7, %6 : i32
        %9 = nvvm.read.ptx.sreg.tid.x : i32
        %10 = llvm.and %9, %3 : i32
        %11 = llvm.shl %10, %1 : i32
        %12 = llvm.or %1, %11 : i32
        %13 = llvm.or %12, %1 : i32
        %14 = llvm.and %13, %3 : i32
        %15 = llvm.lshr %14, %1 : i32
        %16 = llvm.xor %1, %15 : i32
        %17 = llvm.xor %1, %16 : i32
        %18 = llvm.xor %17, %1 : i32
        %19 = llvm.xor %17, %2 : i32
        %20 = llvm.add %18, %4 : i32
        %21 = llvm.add %19, %4 : i32
        %22 = llvm.add %8, %20 : i32
        %23 = llvm.add %8, %21 : i32
        %24 = llvm.icmp "slt" %22, %5 : i32
        %25 = llvm.icmp "slt" %23, %5 : i32
        %26 = llvm.getelementptr %arg0[%22] : (!llvm.ptr<1>, i32) -> !llvm.ptr<1>, f32
        %27 = llvm.getelementptr %arg0[%23] : (!llvm.ptr<1>, i32) -> !llvm.ptr<1>, f32
        %28 = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "mov.u32 $0, 0x0;\0A\09@$2 ld.global.b32 { $0 }, [ $1 + 0 ];", "=r,l,b" %26, %24 : (!llvm.ptr<1>, i1) -> i32
        %29 = llvm.bitcast %28 : i32 to vector<1xf32>
        %30 = llvm.extractelement %29[%4 : i32] : vector<1xf32>
        %31 = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "mov.u32 $0, 0x0;\0A\09@$2 ld.global.b32 { $0 }, [ $1 + 0 ];", "=r,l,b" %27, %25 : (!llvm.ptr<1>, i1) -> i32
        %32 = llvm.bitcast %31 : i32 to vector<1xf32>
        %33 = llvm.extractelement %32[%4 : i32] : vector<1xf32>
        %34 = llvm.getelementptr %arg1[%22] : (!llvm.ptr<1>, i32) -> !llvm.ptr<1>, f32
        %35 = llvm.getelementptr %arg1[%23] : (!llvm.ptr<1>, i32) -> !llvm.ptr<1>, f32
        %36 = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "mov.u32 $0, 0x0;\0A\09@$2 ld.global.b32 { $0 }, [ $1 + 0 ];", "=r,l,b" %34, %24 : (!llvm.ptr<1>, i1) -> i32
        %37 = llvm.bitcast %36 : i32 to vector<1xf32>
        %38 = llvm.extractelement %37[%4 : i32] : vector<1xf32>
        %39 = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "mov.u32 $0, 0x0;\0A\09@$2 ld.global.b32 { $0 }, [ $1 + 0 ];", "=r,l,b" %35, %25 : (!llvm.ptr<1>, i1) -> i32
        %40 = llvm.bitcast %39 : i32 to vector<1xf32>
        %41 = llvm.extractelement %40[%4 : i32] : vector<1xf32>
        %42 = llvm.fadd %30, %38 : f32
        %43 = llvm.fadd %33, %41 : f32
        %44 = llvm.getelementptr %arg2[%22] : (!llvm.ptr<1>, i32) -> !llvm.ptr<1>, f32
        %45 = llvm.getelementptr %arg2[%23] : (!llvm.ptr<1>, i32) -> !llvm.ptr<1>, f32
        %46 = llvm.insertelement %42, %0[%1 : i32] : vector<1xf32>
        %47 = llvm.bitcast %46 : vector<1xf32> to i32
        %48 = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "@$2 st.global.b32 [ $1 + 0 ], { $0 };", "r,l,b" %47, %44, %24 : (i32, !llvm.ptr<1>, i1) -> !llvm.void
        %49 = llvm.insertelement %43, %0[%1 : i32] : vector<1xf32>
        %50 = llvm.bitcast %49 : vector<1xf32> to i32
        %51 = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "@$2 st.global.b32 [ $1 + 0 ], { $0 };", "r,l,b" %50, %45, %25 : (i32, !llvm.ptr<1>, i1) -> !llvm.void
        llvm.return
      }
    }
  }
  func.func @main(%arg0: tensor<1024xf32>, %arg1: tensor<1024xf32>, %arg2: tensor<1024xf32>) attributes {enzymexla.memory_effects = ["read", "write", "allocate", "free"]} {
    %c = stablehlo.constant dense<0> : tensor<i64>
    %c_0 = stablehlo.constant dense<1> : tensor<i64>
    %c_1 = stablehlo.constant dense<16> : tensor<i64>
    triton_ext.call @add_kernel_tt_module_e72661bb113efd0f::@add_kernel_module_e72661bb113efd0f::@add_kernel_call_e72661bb113efd0f blocks in(%c_1, %c_0, %c_0) shmem = %c (%arg0, %arg1, %arg2) : (tensor<1024xf32>, tensor<1024xf32>, tensor<1024xf32>) -> ()
    return
  }
}

avik-pal force-pushed the ap/triton_integration branch from 3315b07 to 4a9a1ce on October 1, 2025 21:33
avik-pal force-pushed the ap/triton_integration branch 2 times, most recently from 38bbe42 to 1042dfb on October 16, 2025 13:52
avik-pal force-pushed the ap/triton_integration branch 3 times, most recently from 891a5dd to 04cbf60 on October 16, 2025 20:54
avik-pal changed the base branch from main to ap/device_props_julia on October 16, 2025 20:54
avik-pal force-pushed the ap/device_props_julia branch from 88b93d3 to 3b5b3fa on October 17, 2025 12:59
avik-pal force-pushed the ap/triton_integration branch from 5b39be4 to 1381077 on October 17, 2025 13:00
avik-pal force-pushed the ap/device_props_julia branch from 3b5b3fa to fd9f3c9 on October 18, 2025 00:23
avik-pal force-pushed the ap/triton_integration branch from 3ed4a70 to 9dde89f on October 18, 2025 00:29
avik-pal force-pushed the ap/device_props_julia branch 2 times, most recently from 190b57e to d5372aa on October 19, 2025 15:00
Base automatically changed from ap/device_props_julia to main on October 21, 2025 18:57
avik-pal force-pushed the ap/triton_integration branch 2 times, most recently from 63ba127 to 3bbc37b on October 28, 2025 23:56
avik-pal force-pushed the ap/triton_integration branch 2 times, most recently from 0bb32b8 to de7b9f4 on November 11, 2025 07:52
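
Commit messages on the branch: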
feat: auto-trace triton code

feat: copy tt.func into main module [skip ci]

feat: tracing fully functional

fix: hlo_call

feat: more triton passes + keep triton func in a separate module

feat: put the tt func in a separate module and use symbol ref

feat: new triton_ext dialect

feat: triton tracing works now finally

fix: kind of working

fix: new API

feat: return values

feat: lowering triton now works

feat: triton working end to end

fix: extra export + naming

feat: allow grid/blocks via a function [skip ci]

feat: use new device properties [skip ci]

feat: correctly set strides + get n_regs

test: add some triton tests

test: layer_norm + libdevice

fix: partial fix to the blocks

fix: correct launch configuration

test: missing vars

chore: bump workspace

fix: cluster dims

fix: bump version

chore: bump
avik-pal force-pushed the ap/triton_integration branch from de7b9f4 to 3e2c89a on November 11, 2025 15:24