Skip to content

Commit 41028a3

Browse files
authored
feat: HLO IR julia bindings (#1839)
feat: parse HloInstructions correctly feat: improved printing of HLO IR feat: device description feat: performance model for gpu chore: run fmt
1 parent 78cf63e commit 41028a3

File tree

7 files changed

+300
-29
lines changed

7 files changed

+300
-29
lines changed

src/xla/Device.jl

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,3 +101,38 @@ function Base.show(io::IO, ::MIME"text/plain", props::DeviceProperties)
101101
""",
102102
)
103103
end
104+
105+
# only for streaming executors like CUDA / ROCM
106+
"""
    StreamExecutorDeviceDescription(ptr::Ptr{Cvoid})

Thin wrapper holding an opaque pointer to a device description object.
The inner constructor rejects `C_NULL` with an `AssertionError`.
"""
mutable struct StreamExecutorDeviceDescription
    ptr::Ptr{Cvoid}

    function StreamExecutorDeviceDescription(ptr::Ptr{Cvoid})
        # Same failure mode as `@assert ptr != C_NULL`, spelled out explicitly.
        ptr == C_NULL && throw(AssertionError("ptr != C_NULL"))
        return new(ptr)
    end
end
114+
115+
"""
    StreamExecutorDeviceDescription(device::AbstractDevice)

Build a device description for `device` by querying the native library with the
device's local hardware id. Only the `"cuda"` platform is handled here; any other
platform name raises an error.
"""
function StreamExecutorDeviceDescription(device::AbstractDevice)
    platform = platform_name(client(device))
    local_hardware_id = get_local_hardware_id(device)

    # Guard clause: everything except CUDA is rejected up front.
    platform == "cuda" || error("Unsupported platform: $(platform)")

    description_ptr = @ccall MLIR.API.mlir_c.CudaGetStreamExecutorDeviceDescription(
        local_hardware_id::Int32
    )::Ptr{Cvoid}
    return StreamExecutorDeviceDescription(description_ptr)
end
129+
130+
# Pretty-print a device description using the string produced by the native library.
function Base.show(io::IO, ::MIME"text/plain", props::StreamExecutorDeviceDescription)
    # Keep `props` rooted while its raw pointer is handed to C.
    description_cstr = GC.@preserve props begin
        @ccall MLIR.API.mlir_c.deviceDescriptionToString(props.ptr::Ptr{Cvoid})::Cstring
    end
    print(io, unsafe_string_and_free(description_cstr))
    return nothing
end

src/xla/HloModule.jl

Lines changed: 0 additions & 28 deletions
This file was deleted.

src/xla/IR/Computation.jl

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
"""
    HloComputation(ptr::Ptr{Cvoid})

Wrapper around an opaque pointer to an HLO computation. Construction from
`C_NULL` fails with an `AssertionError`.
"""
mutable struct HloComputation
    ptr::Ptr{Cvoid}

    function HloComputation(ptr::Ptr{Cvoid})
        # Same failure mode as `@assert ptr != C_NULL`, spelled out explicitly.
        ptr == C_NULL && throw(AssertionError("ptr != C_NULL"))
        return new(ptr)
    end
end
9+
10+
# Hand the wrapped pointer to the native `freeHloComputation` routine.
function free_hlo_computation(hlo_computation)
    handle = hlo_computation.ptr
    return @ccall MLIR.API.mlir_c.freeHloComputation(handle::Ptr{Cvoid})::Cvoid
end
13+
14+
# Virtual property `:instructions` yields the computation's instructions as a vector;
# every other name is looked up as a regular field.
function Base.getproperty(hlo_computation::HloComputation, sym::Symbol)
    sym === :instructions || return getfield(hlo_computation, sym)
    return convert(Vector{HloInstruction}, hlo_computation)
end
20+
21+
# Print the textual form of a computation; print options are derived from `io`.
function Base.show(io::IO, hlo_computation::HloComputation)
    # Keep the wrapper rooted while its raw pointer is in use by C.
    text_cstr = GC.@preserve hlo_computation begin
        @ccall MLIR.API.mlir_c.hloComputationToString(
            hlo_computation.ptr::Ptr{Cvoid}, _iobuffer_to_hlo_print_options(io)::Int32
        )::Cstring
    end
    print(io, unsafe_string_and_free(text_cstr))
    return nothing
end
30+
31+
"""
    convert(Vector{HloInstruction}, hlo_computation::HloComputation)

Collect the instructions of `hlo_computation` into a `Vector{HloInstruction}`, in the
order written by the native `hloComputationGetInstructionsPostOrder` routine.

The result is always a `Vector{HloInstruction}`, including when the computation
reports zero instructions.
"""
function Base.convert(::Type{Vector{HloInstruction}}, hlo_computation::HloComputation)
    # Keep the wrapper rooted across both native calls that use its raw pointer.
    GC.@preserve hlo_computation begin
        num_instructions = @ccall MLIR.API.mlir_c.hloComputationInstructionCount(
            hlo_computation.ptr::Ptr{Cvoid}
        )::Int64

        # A plain buffer instead of `Ref{NTuple{N,...}}`: using the runtime count as a
        # type parameter forces a fresh tuple type (and fresh compilation) per count.
        instruction_ptrs = Vector{Ptr{Cvoid}}(undef, num_instructions)
        @ccall MLIR.API.mlir_c.hloComputationGetInstructionsPostOrder(
            hlo_computation.ptr::Ptr{Cvoid},
            num_instructions::Int64,
            instruction_ptrs::Ptr{Ptr{Cvoid}},
        )::Cvoid
    end

    # Typed comprehension: no tuple splatting, and the element type is
    # `HloInstruction` even for an empty result.
    return HloInstruction[HloInstruction(p) for p in instruction_ptrs]
end

src/xla/IR/Instruction.jl

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
"""
    HloInstruction(ptr::Ptr{Cvoid})

Wrapper around an opaque pointer to an HLO instruction. Construction from
`C_NULL` fails with an `AssertionError`.
"""
mutable struct HloInstruction
    ptr::Ptr{Cvoid}

    function HloInstruction(ptr::Ptr{Cvoid})
        # Same failure mode as `@assert ptr != C_NULL`, spelled out explicitly.
        ptr == C_NULL && throw(AssertionError("ptr != C_NULL"))
        return new(ptr)
    end
end
9+
10+
# Hand the wrapped pointer to the native `freeHloInstruction` routine.
function free_hlo_instruction(hlo_instruction)
    handle = hlo_instruction.ptr
    return @ccall MLIR.API.mlir_c.freeHloInstruction(handle::Ptr{Cvoid})::Cvoid
end
13+
14+
# Print the textual form of an instruction; print options are derived from `io`.
function Base.show(io::IO, hlo_instruction::HloInstruction)
    # Keep the wrapper rooted while its raw pointer is in use by C.
    text_cstr = GC.@preserve hlo_instruction begin
        @ccall MLIR.API.mlir_c.hloInstructionToString(
            hlo_instruction.ptr::Ptr{Cvoid}, _iobuffer_to_hlo_print_options(io)::Int32
        )::Cstring
    end
    print(io, unsafe_string_and_free(text_cstr))
    return nothing
end
23+
24+
# Virtual properties on an instruction:
#   :opcode                         -> HloOpcode
#   :to_apply                       -> HloComputation (asserts `has_to_apply`)
#   :fusion_kind                    -> HloFusionKind  (asserts `is_fusion_instruction`)
#   :fused_instructions_computation -> HloComputation (asserts `is_fusion_instruction`)
# Any other name falls through to a regular field lookup.
function Base.getproperty(hlo_instruction::HloInstruction, sym::Symbol)
    if sym === :opcode
        raw_opcode = @ccall MLIR.API.mlir_c.hloInstructionGetOpcode(
            hlo_instruction.ptr::Ptr{Cvoid}
        )::UInt8
        return HloOpcode(raw_opcode)
    elseif sym === :to_apply
        @assert has_to_apply(hlo_instruction)
        applied_ptr = @ccall MLIR.API.mlir_c.hloInstructionGetToApply(
            hlo_instruction.ptr::Ptr{Cvoid}
        )::Ptr{Cvoid}
        return HloComputation(applied_ptr)
    elseif sym === :fusion_kind
        @assert is_fusion_instruction(hlo_instruction)
        raw_kind = @ccall MLIR.API.mlir_c.hloInstructionGetFusionKind(
            hlo_instruction.ptr::Ptr{Cvoid}
        )::UInt8
        return HloFusionKind(raw_kind)
    elseif sym === :fused_instructions_computation
        @assert is_fusion_instruction(hlo_instruction)
        fused_ptr = @ccall MLIR.API.mlir_c.hloInstructionFusedInstructionsComputation(
            hlo_instruction.ptr::Ptr{Cvoid}
        )::Ptr{Cvoid}
        return HloComputation(fused_ptr)
    end
    return getfield(hlo_instruction, sym)
end
58+
59+
# Ask the native library whether the instruction has a `to_apply` computation.
# The native call returns a byte; the result is `true` only when that byte equals 1.
function has_to_apply(hlo_instruction::HloInstruction)
    flag = @ccall MLIR.API.mlir_c.hloInstructionHasToApply(
        hlo_instruction.ptr::Ptr{Cvoid}
    )::UInt8
    return flag == 0x01
end
65+
66+
# Ask the native library whether the instruction is a fusion instruction.
# The native call returns a byte; the result is `true` only when that byte equals 1.
function is_fusion_instruction(hlo_instruction::HloInstruction)
    flag = @ccall MLIR.API.mlir_c.hloInstructionIsFusion(
        hlo_instruction.ptr::Ptr{Cvoid}
    )::UInt8
    return flag == 0x01
end
72+
73+
"""
    HloOpcode(opcode::UInt8)

Raw one-byte opcode value as returned by the native library. Its textual name is
obtained through `show`.
"""
struct HloOpcode
    opcode::UInt8
end
76+
77+
# Print the opcode's name as reported by the native `hloOpcodeToString`.
function Base.show(io::IO, hlo_opcode::HloOpcode)
    name_cstr = @ccall MLIR.API.mlir_c.hloOpcodeToString(hlo_opcode.opcode::UInt8)::Cstring
    print(io, unsafe_string_and_free(name_cstr))
    return nothing
end
86+
87+
"""
    HloFusionKind(kind::UInt8)

Raw one-byte fusion-kind value as returned by the native library. Its textual name
is obtained through `show`.
"""
struct HloFusionKind
    kind::UInt8
end
90+
91+
# Print the fusion kind's name as reported by the native `hloFusionKindToString`.
function Base.show(io::IO, fusion_kind::HloFusionKind)
    name_cstr = @ccall MLIR.API.mlir_c.hloFusionKindToString(
        fusion_kind.kind::UInt8
    )::Cstring
    print(io, unsafe_string_and_free(name_cstr))
    return nothing
end

src/xla/IR/Module.jl

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
"""
    HloModule(ptr::Ptr{Cvoid})

Wrapper around an opaque pointer to an HLO module. Construction from `C_NULL`
fails with an `AssertionError`. `free_hlo_module` is registered as the finalizer
of every instance.
"""
mutable struct HloModule
    ptr::Ptr{Cvoid}

    function HloModule(ptr::Ptr{Cvoid})
        # Same failure mode as `@assert ptr != C_NULL`, spelled out explicitly.
        ptr == C_NULL && throw(AssertionError("ptr != C_NULL"))
        hlo_module = new(ptr)
        finalizer(free_hlo_module, hlo_module)
        return hlo_module
    end
end
9+
10+
# Hand the wrapped pointer to the native `FreeHloModule` routine.
function free_hlo_module(hlo_module)
    handle = hlo_module.ptr
    return @ccall MLIR.API.mlir_c.FreeHloModule(handle::Ptr{Cvoid})::Cvoid
end
13+
14+
# Convert an MLIR module into an HLO module via the native converter.
function HloModule(mod::MLIR.IR.Module)
    module_ptr = @ccall MLIR.API.mlir_c.convertMlirModuleToHloModule(
        mod::MLIR.API.MlirModule
    )::Ptr{Cvoid}
    return HloModule(module_ptr)
end
21+
22+
# Map `IOContext` flags on `io` to the integer print-option code passed to the
# native printers. Flags are checked in priority order and the first one set wins:
#   :compact => 1 (ShortParsable), :canonical => 2 (Canonical),
#   :fingerprint => 3 (Fingerprint), :module_fingerprint => 4 (ModuleFingerprint).
# With no flag set the result is 0 (Default).
function _iobuffer_to_hlo_print_options(io::IO)
    flag_priority = (:compact, :canonical, :fingerprint, :module_fingerprint)
    for (code, flag) in enumerate(flag_priority)
        get(io, flag, false) && return Int32(code)
    end
    return Int32(0)
end
29+
30+
# Print the textual form of a module; print options are derived from `io`.
function Base.show(io::IO, hlo_module::HloModule)
    # Keep the wrapper rooted while its raw pointer is in use by C.
    text_cstr = GC.@preserve hlo_module begin
        @ccall MLIR.API.mlir_c.HloModuleToString(
            hlo_module.ptr::Ptr{Cvoid}, _iobuffer_to_hlo_print_options(io)::Int32
        )::Cstring
    end
    print(io, unsafe_string_and_free(text_cstr))
    return nothing
end
39+
40+
# Parse HLO text into an `HloModule` using the native
# `parseAndReturnUnverifiedHloModule` routine.
function Base.parse(::Type{HloModule}, str::AbstractString)
    module_ptr = @ccall MLIR.API.mlir_c.parseAndReturnUnverifiedHloModule(
        str::Cstring
    )::Ptr{Cvoid}
    return HloModule(module_ptr)
end
45+
46+
# Read the whole file as text and parse it as an HLO module.
function Base.read(filename::AbstractString, ::Type{HloModule})
    hlo_text = read(filename, String)
    return parse(HloModule, hlo_text)
end
49+
50+
# Virtual property `:entry_computation` wraps the pointer returned by the native
# `hloModuleGetEntryComputation`; every other name is a regular field lookup.
function Base.getproperty(hlo_module::HloModule, sym::Symbol)
    sym === :entry_computation || return getfield(hlo_module, sym)
    entry_ptr = @ccall MLIR.API.mlir_c.hloModuleGetEntryComputation(
        hlo_module.ptr::Ptr{Cvoid}
    )::Ptr{Cvoid}
    return HloComputation(entry_ptr)
end

src/xla/IR/PerformanceModel.jl

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
# currently only supported for CUDA and ROCM
2+
"""
    GPUPerformanceModel(ptr::Ptr{Cvoid})

Immutable wrapper around an opaque pointer to a native GPU performance model.
Construction from `C_NULL` fails with an `AssertionError`.
"""
struct GPUPerformanceModel
    ptr::Ptr{Cvoid}

    function GPUPerformanceModel(ptr::Ptr{Cvoid})
        # Same failure mode as `@assert ptr != C_NULL`, spelled out explicitly.
        ptr == C_NULL && throw(AssertionError("ptr != C_NULL"))
        return new(ptr)
    end
end
10+
11+
# Create a performance model from an MLIR context and a device description by
# calling the native `CreateGPUPerformanceModel`.
function GPUPerformanceModel(
    mlir_context::MLIR.IR.Context, device_description::StreamExecutorDeviceDescription
)
    model_ptr = @ccall MLIR.API.mlir_c.CreateGPUPerformanceModel(
        mlir_context::MLIR.API.MlirContext, device_description.ptr::Ptr{Cvoid}
    )::Ptr{Cvoid}
    return GPUPerformanceModel(model_ptr)
end
20+
21+
# Runs the analysis on the given HLO module.
22+
# Calling the model on a module runs the native `RunAnalysisOnHloModule`; nothing
# is returned.
function (gpu_performance_model::GPUPerformanceModel)(hlo_module::HloModule)
    # Keep the module wrapper rooted while its raw pointer is in use by C.
    GC.@preserve hlo_module begin
        @ccall MLIR.API.mlir_c.RunAnalysisOnHloModule(
            gpu_performance_model.ptr::Ptr{Cvoid}, hlo_module.ptr::Ptr{Cvoid}
        )::Cvoid
    end
    return nothing
end
30+
31+
# Calling the model on an instruction forwards to `estimate_runtime_for_instruction`.
function (gpu_performance_model::GPUPerformanceModel)(hlo_instruction::HloInstruction)
    estimate = estimate_runtime_for_instruction(gpu_performance_model, hlo_instruction)
    return estimate
end
34+
35+
## To keep in sync with JLEstimateRunTimeData in ReactantExtra/API.cpp
36+
"""
    EstimateRunTimeData

Plain record of seven `Int64` fields that the native
`EstimateRunTimeForInstruction` call writes into. The field order and types must
stay in sync with `JLEstimateRunTimeData` in ReactantExtra/API.cpp.
"""
struct EstimateRunTimeData
    # Work and traffic counters.
    flops::Int64
    bytes_read::Int64
    bytes_written::Int64
    # Time components; the `_ns` suffix suggests nanoseconds — confirm against API.cpp.
    read_time_ns::Int64
    write_time_ns::Int64
    compute_time_ns::Int64
    execution_time_ns::Int64
end
45+
46+
# Ask the native performance model for a runtime estimate of one instruction.
# The native call fills an `EstimateRunTimeData` record in place, which is then
# returned by value.
function estimate_runtime_for_instruction(
    performance_model::GPUPerformanceModel, hlo_instruction::HloInstruction
)
    estimate_ref = Ref{EstimateRunTimeData}()
    # Keep every object whose memory C touches rooted for the duration of the call.
    GC.@preserve performance_model hlo_instruction estimate_ref begin
        @ccall MLIR.API.mlir_c.EstimateRunTimeForInstruction(
            performance_model.ptr::Ptr{Cvoid},
            hlo_instruction.ptr::Ptr{Cvoid},
            estimate_ref::Ptr{EstimateRunTimeData},
        )::Cvoid
    end
    return estimate_ref[]
end

src/xla/XLA.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,13 @@ include("Future.jl")
3232
include("Buffer.jl")
3333
include("Stats.jl")
3434
include("Utils.jl")
35-
include("HloModule.jl")
3635
include("Memory.jl")
3736

37+
include("IR/Module.jl")
38+
include("IR/Instruction.jl")
39+
include("IR/Computation.jl")
40+
include("IR/PerformanceModel.jl")
41+
3842
include("PJRT/PJRT.jl")
3943

4044
include("IFRT/IFRT.jl")

0 commit comments

Comments
 (0)