Skip to content

Commit 7bee86b

Browse files
committed
feat: new triton_ext dialect
1 parent 4ea68f8 commit 7bee86b

File tree

5 files changed

+103
-49
lines changed

5 files changed

+103
-49
lines changed

deps/ReactantExtra/BUILD

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library")
1+
load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library")
22
load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm")
33
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
44
load("@xla//tools/toolchains/cross_compile/cc:cc_toolchain_config.bzl", "cc_toolchain_config")
@@ -1435,6 +1435,24 @@ gentbl_cc_library(
14351435
],
14361436
)
14371437

1438+
gentbl_cc_library(
1439+
name = "TritonExtJLIncGen",
1440+
tbl_outs = [
1441+
(
1442+
[
1443+
"--generator=jl-op-defs",
1444+
"--disable-module-wrap=0",
1445+
],
1446+
"TritonExt.jl",
1447+
),
1448+
],
1449+
tblgen = "//:mlir-jl-tblgen",
1450+
td_file = "@enzyme_ad//src/enzyme_ad/jax:Dialect/TritonExt/Ops.td",
1451+
deps = [
1452+
"@enzyme_ad//src/enzyme_ad/jax:TritonExtDialectTdFiles",
1453+
],
1454+
)
1455+
14381456
gentbl_cc_library(
14391457
name = "TPUJLIncGen",
14401458
tbl_outs = [

deps/ReactantExtra/WORKSPACE

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ NSYNC_COMMIT = "82b118aa7ace3132e517e2c467f8732978cf4023"
44

55
NSYNC_SHA256 = ""
66

7-
ENZYMEXLA_COMMIT = "52ae936cae8f7050adc26c4ed5e755200497dc86"
7+
ENZYMEXLA_COMMIT = "9867ac059bb2f312a1a6d559d2b41d8ba333a589"
88

99
ENZYMEXLA_SHA256 = ""
1010

deps/ReactantExtra/make-bindings.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ for file in [
4242
"MPI.jl",
4343
"MemRef.jl",
4444
"SparseTensor.jl",
45+
"TritonExt.jl"
4546
]
4647
build_file(joinpath(src_dir, "mlir", "Dialects", file))
4748
end

src/mlir/Dialects/EnzymeXLA.jl

Lines changed: 0 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -792,53 +792,6 @@ function stream2token(source::Value; result::IR.Type, location=Location())
792792
)
793793
end
794794

795-
function triton_call(
796-
gridx::Value,
797-
gridy::Value,
798-
gridz::Value,
799-
shmem::Value,
800-
inputs::Vector{Value};
801-
result_0::Vector{IR.Type},
802-
fn,
803-
backend_config=nothing,
804-
operand_layouts=nothing,
805-
result_layouts=nothing,
806-
arg_attrs=nothing,
807-
res_attrs=nothing,
808-
output_operand_aliases=nothing,
809-
xla_side_effect_free=nothing,
810-
location=Location(),
811-
)
812-
op_ty_results = IR.Type[result_0...,]
813-
operands = Value[gridx, gridy, gridz, shmem, inputs...]
814-
owned_regions = Region[]
815-
successors = Block[]
816-
attributes = NamedAttribute[namedattribute("fn", fn),]
817-
!isnothing(backend_config) &&
818-
push!(attributes, namedattribute("backend_config", backend_config))
819-
!isnothing(operand_layouts) &&
820-
push!(attributes, namedattribute("operand_layouts", operand_layouts))
821-
!isnothing(result_layouts) &&
822-
push!(attributes, namedattribute("result_layouts", result_layouts))
823-
!isnothing(arg_attrs) && push!(attributes, namedattribute("arg_attrs", arg_attrs))
824-
!isnothing(res_attrs) && push!(attributes, namedattribute("res_attrs", res_attrs))
825-
!isnothing(output_operand_aliases) &&
826-
push!(attributes, namedattribute("output_operand_aliases", output_operand_aliases))
827-
!isnothing(xla_side_effect_free) &&
828-
push!(attributes, namedattribute("xla_side_effect_free", xla_side_effect_free))
829-
830-
return create_operation(
831-
"enzymexla.triton_call",
832-
location;
833-
operands,
834-
owned_regions,
835-
successors,
836-
attributes,
837-
results=op_ty_results,
838-
result_inference=false,
839-
)
840-
end
841-
842795
function wrap(
843796
operand::Value;
844797
result=nothing::Union{Nothing,IR.Type},

src/mlir/Dialects/TritonExt.jl

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
module triton_ext
2+
using ...IR
3+
import ...IR:
4+
NamedAttribute,
5+
Value,
6+
Location,
7+
Block,
8+
Region,
9+
Attribute,
10+
create_operation,
11+
context,
12+
IndexType
13+
import ..Dialects: namedattribute, operandsegmentsizes
14+
import ...API
15+
16+
function call(
17+
gridx::Value,
18+
gridy::Value,
19+
gridz::Value,
20+
shmem::Value,
21+
inputs::Vector{Value};
22+
result_0::Vector{IR.Type},
23+
fn,
24+
backend_config=nothing,
25+
operand_layouts=nothing,
26+
result_layouts=nothing,
27+
arg_attrs=nothing,
28+
res_attrs=nothing,
29+
output_operand_aliases=nothing,
30+
xla_side_effect_free=nothing,
31+
location=Location(),
32+
)
33+
op_ty_results = IR.Type[result_0...,]
34+
operands = Value[gridx, gridy, gridz, shmem, inputs...]
35+
owned_regions = Region[]
36+
successors = Block[]
37+
attributes = NamedAttribute[namedattribute("fn", fn),]
38+
!isnothing(backend_config) &&
39+
push!(attributes, namedattribute("backend_config", backend_config))
40+
!isnothing(operand_layouts) &&
41+
push!(attributes, namedattribute("operand_layouts", operand_layouts))
42+
!isnothing(result_layouts) &&
43+
push!(attributes, namedattribute("result_layouts", result_layouts))
44+
!isnothing(arg_attrs) && push!(attributes, namedattribute("arg_attrs", arg_attrs))
45+
!isnothing(res_attrs) && push!(attributes, namedattribute("res_attrs", res_attrs))
46+
!isnothing(output_operand_aliases) &&
47+
push!(attributes, namedattribute("output_operand_aliases", output_operand_aliases))
48+
!isnothing(xla_side_effect_free) &&
49+
push!(attributes, namedattribute("xla_side_effect_free", xla_side_effect_free))
50+
51+
return create_operation(
52+
"triton_ext.call",
53+
location;
54+
operands,
55+
owned_regions,
56+
successors,
57+
attributes,
58+
results=op_ty_results,
59+
result_inference=false,
60+
)
61+
end
62+
63+
function module_(; sym_name, bodyRegion::Region, location=Location())
64+
op_ty_results = IR.Type[]
65+
operands = Value[]
66+
owned_regions = Region[bodyRegion,]
67+
successors = Block[]
68+
attributes = NamedAttribute[namedattribute("sym_name", sym_name),]
69+
70+
return create_operation(
71+
"triton_ext.module",
72+
location;
73+
operands,
74+
owned_regions,
75+
successors,
76+
attributes,
77+
results=op_ty_results,
78+
result_inference=false,
79+
)
80+
end
81+
82+
end # triton_ext

0 commit comments

Comments
 (0)