Skip to content

Commit 30976c0

Browse files
committed
feat: new triton_ext dialect
1 parent e29cd0a commit 30976c0

File tree

4 files changed

+101
-47
lines changed

4 files changed

+101
-47
lines changed

deps/ReactantExtra/BUILD

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1438,6 +1438,24 @@ gentbl_cc_library(
14381438
],
14391439
)
14401440

1441+
gentbl_cc_library(
1442+
name = "TritonExtJLIncGen",
1443+
tbl_outs = [
1444+
(
1445+
[
1446+
"--generator=jl-op-defs",
1447+
"--disable-module-wrap=0",
1448+
],
1449+
"TritonExt.jl",
1450+
),
1451+
],
1452+
tblgen = "//:mlir-jl-tblgen",
1453+
td_file = "@enzyme_ad//src/enzyme_ad/jax:Dialect/TritonExt/Ops.td",
1454+
deps = [
1455+
"@enzyme_ad//src/enzyme_ad/jax:TritonExtDialectTdFiles",
1456+
],
1457+
)
1458+
14411459
gentbl_cc_library(
14421460
name = "TPUJLIncGen",
14431461
tbl_outs = [

deps/ReactantExtra/make-bindings.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ for file in [
4242
"MPI.jl",
4343
"MemRef.jl",
4444
"SparseTensor.jl",
45+
"TritonExt.jl"
4546
]
4647
build_file(joinpath(src_dir, "mlir", "Dialects", file))
4748
end

src/mlir/Dialects/EnzymeXLA.jl

Lines changed: 0 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -837,53 +837,6 @@ function stream2token(source::Value; result::IR.Type, location=Location())
837837
)
838838
end
839839

840-
function triton_call(
841-
gridx::Value,
842-
gridy::Value,
843-
gridz::Value,
844-
shmem::Value,
845-
inputs::Vector{Value};
846-
result_0::Vector{IR.Type},
847-
fn,
848-
backend_config=nothing,
849-
operand_layouts=nothing,
850-
result_layouts=nothing,
851-
arg_attrs=nothing,
852-
res_attrs=nothing,
853-
output_operand_aliases=nothing,
854-
xla_side_effect_free=nothing,
855-
location=Location(),
856-
)
857-
op_ty_results = IR.Type[result_0...,]
858-
operands = Value[gridx, gridy, gridz, shmem, inputs...]
859-
owned_regions = Region[]
860-
successors = Block[]
861-
attributes = NamedAttribute[namedattribute("fn", fn),]
862-
!isnothing(backend_config) &&
863-
push!(attributes, namedattribute("backend_config", backend_config))
864-
!isnothing(operand_layouts) &&
865-
push!(attributes, namedattribute("operand_layouts", operand_layouts))
866-
!isnothing(result_layouts) &&
867-
push!(attributes, namedattribute("result_layouts", result_layouts))
868-
!isnothing(arg_attrs) && push!(attributes, namedattribute("arg_attrs", arg_attrs))
869-
!isnothing(res_attrs) && push!(attributes, namedattribute("res_attrs", res_attrs))
870-
!isnothing(output_operand_aliases) &&
871-
push!(attributes, namedattribute("output_operand_aliases", output_operand_aliases))
872-
!isnothing(xla_side_effect_free) &&
873-
push!(attributes, namedattribute("xla_side_effect_free", xla_side_effect_free))
874-
875-
return create_operation(
876-
"enzymexla.triton_call",
877-
location;
878-
operands,
879-
owned_regions,
880-
successors,
881-
attributes,
882-
results=op_ty_results,
883-
result_inference=false,
884-
)
885-
end
886-
887840
function wrap(
888841
operand::Value;
889842
result=nothing::Union{Nothing,IR.Type},

src/mlir/Dialects/TritonExt.jl

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
module triton_ext
2+
using ...IR
3+
import ...IR:
4+
NamedAttribute,
5+
Value,
6+
Location,
7+
Block,
8+
Region,
9+
Attribute,
10+
create_operation,
11+
context,
12+
IndexType
13+
import ..Dialects: namedattribute, operandsegmentsizes
14+
import ...API
15+
16+
function call(
17+
gridx::Value,
18+
gridy::Value,
19+
gridz::Value,
20+
shmem::Value,
21+
inputs::Vector{Value};
22+
result_0::Vector{IR.Type},
23+
fn,
24+
backend_config=nothing,
25+
operand_layouts=nothing,
26+
result_layouts=nothing,
27+
arg_attrs=nothing,
28+
res_attrs=nothing,
29+
output_operand_aliases=nothing,
30+
xla_side_effect_free=nothing,
31+
location=Location(),
32+
)
33+
op_ty_results = IR.Type[result_0...,]
34+
operands = Value[gridx, gridy, gridz, shmem, inputs...]
35+
owned_regions = Region[]
36+
successors = Block[]
37+
attributes = NamedAttribute[namedattribute("fn", fn),]
38+
!isnothing(backend_config) &&
39+
push!(attributes, namedattribute("backend_config", backend_config))
40+
!isnothing(operand_layouts) &&
41+
push!(attributes, namedattribute("operand_layouts", operand_layouts))
42+
!isnothing(result_layouts) &&
43+
push!(attributes, namedattribute("result_layouts", result_layouts))
44+
!isnothing(arg_attrs) && push!(attributes, namedattribute("arg_attrs", arg_attrs))
45+
!isnothing(res_attrs) && push!(attributes, namedattribute("res_attrs", res_attrs))
46+
!isnothing(output_operand_aliases) &&
47+
push!(attributes, namedattribute("output_operand_aliases", output_operand_aliases))
48+
!isnothing(xla_side_effect_free) &&
49+
push!(attributes, namedattribute("xla_side_effect_free", xla_side_effect_free))
50+
51+
return create_operation(
52+
"triton_ext.call",
53+
location;
54+
operands,
55+
owned_regions,
56+
successors,
57+
attributes,
58+
results=op_ty_results,
59+
result_inference=false,
60+
)
61+
end
62+
63+
function module_(; sym_name, bodyRegion::Region, location=Location())
64+
op_ty_results = IR.Type[]
65+
operands = Value[]
66+
owned_regions = Region[bodyRegion,]
67+
successors = Block[]
68+
attributes = NamedAttribute[namedattribute("sym_name", sym_name),]
69+
70+
return create_operation(
71+
"triton_ext.module",
72+
location;
73+
operands,
74+
owned_regions,
75+
successors,
76+
attributes,
77+
results=op_ty_results,
78+
result_inference=false,
79+
)
80+
end
81+
82+
end # triton_ext

0 commit comments

Comments
 (0)