Skip to content

Commit a270ce4

Browse files
committed
feat: julia api to access device properties [skip ci]
1 parent 5a3c3bc commit a270ce4

File tree

6 files changed

+117
-27
lines changed

6 files changed

+117
-27
lines changed

src/Compiler.jl

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1395,7 +1395,8 @@ function __get_compile_options_and_kwargs(;
13951395
end
13961396

13971397
function compile_mlir(f, args; client=nothing, kwargs...)
1398-
backend = XLA.platform_name(client !== nothing ? client : XLA.default_backend())
1398+
client = client !== nothing ? client : XLA.default_backend()
1399+
backend = XLA.platform_name(client)
13991400

14001401
if backend == "CUDA"
14011402
backend = "GPU"
@@ -1414,6 +1415,7 @@ function compile_mlir(f, args; client=nothing, kwargs...)
14141415
compile_options;
14151416
backend,
14161417
runtime=XLA.runtime(client),
1418+
client,
14171419
kwargs...,
14181420
)
14191421

@@ -1430,11 +1432,9 @@ end
14301432

14311433
const PartitionKA = Ref{Bool}(true)
14321434

1433-
const cubinChip = Ref{String}("sm_60")
1434-
const cubinFormat = Ref{String}("bin")
14351435
const cuindexBitWidth = Ref{Int}(32)
1436+
const cubinFormat = Ref{String}("bin")
14361437
const cuOptLevel = Ref{Int}(2)
1437-
const cuWarpSize = Ref{Int}(32)
14381438

14391439
# Wgatever the relevant highest version from our LLVM is within NVPTX.td
14401440
# Or more specifically looking at clang/lib/Driver/ToolChains/Cuda.cpp:684
@@ -1580,8 +1580,11 @@ function compile_mlir!(
15801580
backend="gpu",
15811581
runtime::Union{Val{:PJRT},Val{:IFRT}},
15821582
legalize_stablehlo_to_mhlo::Bool=false,
1583+
client=nothing,
15831584
kwargs...,
15841585
)
1586+
client = client !== nothing ? client : XLA.default_backend()
1587+
15851588
# Explicitly don't use block! to avoid creating a closure, which creates
15861589
# both compile-time and relocatability issues
15871590

@@ -1655,25 +1658,27 @@ function compile_mlir!(
16551658
else
16561659
jit = "lower-jit{openmp=$(OpenMP[]) backend=cpu},symbol-dce"
16571660
end
1658-
elseif DEBUG_KERNEL[]
1659-
curesulthandler = dlsym(
1660-
Reactant_jll.libReactantExtra_handle, "ReactantHandleCuResult"
1661-
)
1662-
@assert curesulthandler !== nothing
1663-
curesulthandler = Base.reinterpret(UInt, curesulthandler)
1661+
else
16641662
kern = if is_raising
16651663
"lower-kernel{backend=cpu},symbol-dce,canonicalize"
16661664
else
16671665
"lower-kernel,canonicalize"
16681666
end
1669-
jit = "lower-jit{debug=true cuResultHandlerPtr=$curesulthandler cuOptLevel=$(cuOptLevel[]) cubinFormat=$(cubinFormat[]) indexBitWidth=$(cuindexBitWidth[]) cubinChip=$(cubinChip[]) cubinFeatures=$(cubinFeatures()) run_init=true toolkitPath=$toolkit},symbol-dce"
1670-
else
1671-
kern = if is_raising
1672-
"lower-kernel{backend=cpu},symbol-dce,canonicalize"
1667+
1668+
device_properties = XLA.device_properties(XLA.default_device(client))
1669+
cubinChip = "sm_$(device_properties.major)$(device_properties.minor)"
1670+
1671+
if DEBUG_KERNEL[]
1672+
curesulthandler = dlsym(
1673+
Reactant_jll.libReactantExtra_handle, "ReactantHandleCuResult"
1674+
)
1675+
@assert curesulthandler !== nothing
1676+
curesulthandler = Base.reinterpret(UInt, curesulthandler)
1677+
extra_lowerjit_options = "debug=true cuResultHandlerPtr=$curesulthandler "
16731678
else
1674-
"lower-kernel,canonicalize"
1679+
extra_lowerjit_options = ""
16751680
end
1676-
jit = "lower-jit{cuOptLevel=$(cuOptLevel[]) indexBitWidth=$(cuindexBitWidth[]) cubinFormat=$(cubinFormat[]) cubinChip=$(cubinChip[]) cubinFeatures=$(cubinFeatures()) run_init=true toolkitPath=$toolkit},symbol-dce"
1681+
jit = "lower-jit{$(extra_lowerjit_options)cuOptLevel=$(cuOptLevel[]) cubinFormat=$(cubinFormat[]) indexBitWidth=$(cuindexBitWidth[]) cubinChip=$(cubinChip) cubinFeatures=$(cubinFeatures()) run_init=true toolkitPath=$toolkit},symbol-dce"
16771682
end
16781683

16791684
recognize_comms = true
@@ -3477,7 +3482,8 @@ function compile_xla(
34773482
context_gc_vector[ctx] = Vector{Union{TracedRArray,TracedRNumber}}(undef, 0)
34783483
@ccall MLIR.API.mlir_c.RegisterDialects(ctx::MLIR.API.MlirContext)::Cvoid
34793484

3480-
backend = XLA.platform_name(client !== nothing ? client : XLA.default_backend())
3485+
client = client !== nothing ? client : XLA.default_backend()
3486+
backend = XLA.platform_name(client)
34813487

34823488
if backend == "CUDA"
34833489
backend = "GPU"
@@ -3498,6 +3504,7 @@ function compile_xla(
34983504
compile_options;
34993505
backend,
35003506
runtime=XLA.runtime(client),
3507+
client,
35013508
kwargs...,
35023509
)
35033510

src/xla/Device.jl

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ function device_kind end
1111
function default_memory end
1212
function memories end
1313
function is_addressable end
14+
function get_local_hardware_id end
1415

1516
"""
1617
device_ordinal(device::Device)
@@ -29,3 +30,78 @@ end
2930
# Fallback `is_addressable`: a device is addressable iff its client lists it.
# NOTE(review): the membership operator (`in`/`∈`) was dropped in this copy,
# leaving an unparsable juxtaposition; restored here.
function is_addressable(device::AbstractDevice)
    return device in addressable_devices(client(device))
end
33+
34+
# Mirror of the C-side device-properties struct: field order, types, and layout
# must stay byte-compatible with the definition in API.cpp, because instances
# are filled in-place through a `ccall` (see `device_properties` below).
# Keep in sync with API.cpp
struct DeviceProperties
    total_global_mem::Csize_t     # total device memory, in bytes
    shared_mem_per_block::Csize_t # shared memory available per block, in bytes
    regs_per_block::Cint          # registers available per block
    warp_size::Cint               # threads per warp
    max_threads_per_block::Cint
    max_threads_dim::NTuple{3,Cint} # max block dimensions (x, y, z)
    max_grid_size::NTuple{3,Cint}   # max grid dimensions (x, y, z)
    clock_rate::Cint              # NOTE(review): units not visible here — presumably kHz as in CUDA; confirm against API.cpp
    total_const_mem::Csize_t      # constant memory, in bytes
    major::Cint                   # compute-capability major (used as "sm_{major}{minor}" in Compiler.jl)
    minor::Cint                   # compute-capability minor
    multi_processor_count::Cint
    can_map_host_memory::Cint
    compute_mode::Cint
    l2_cache_size::Cint           # in bytes
    max_threads_per_multiprocessor::Cint
end

# Cache keyed by (local hardware id, platform name) so the native property
# query runs at most once per physical device per platform.
const DEVICE_PROPERTIES_CACHE = Dict{Tuple{Int,String},DeviceProperties}()
55+
56+
"""
57+
device_properties(device::AbstractDevice)
58+
59+
Get a struct containing device properties. Which exact fields are populated relies on the
60+
underlying device implementation.
61+
"""
62+
function device_properties(device::AbstractDevice)
63+
pname = platform_name(client(device))
64+
local_hardware_id = get_local_hardware_id(device)
65+
66+
if haskey(DEVICE_PROPERTIES_CACHE, (local_hardware_id, pname))
67+
return DEVICE_PROPERTIES_CACHE[(local_hardware_id, pname)]
68+
end
69+
70+
jldevprops = Ref{DeviceProperties}()
71+
if pname == "cuda"
72+
GC.@preserve jldevprops begin
73+
@ccall MLIR.API.mlir_c.ReactantCudaDeviceGetProperties(
74+
jldevprops::Ptr{Cvoid}, local_hardware_id::Cint
75+
)::Cvoid
76+
end
77+
else
78+
@warn "`get_properties` not implemented for platform: $(pname)" maxlog = 1
79+
end
80+
DEVICE_PROPERTIES_CACHE[(local_hardware_id, pname)] = jldevprops[]
81+
return jldevprops[]
82+
end
83+
84+
# Multi-line REPL display for `DeviceProperties`. Byte sizes go through
# `_format_bytes`, which yields `nothing` for negative (unavailable) values.
function Base.show(io::IO, ::MIME"text/plain", props::DeviceProperties)
    rows = (
        "DeviceProperties",
        "----------------",
        "Total Global Mem: $(_format_bytes(props.total_global_mem))",
        "Shared Mem Per Block: $(_format_bytes(props.shared_mem_per_block))",
        "Regs Per Block: $(props.regs_per_block)",
        "Warp Size: $(props.warp_size)",
        "Max Threads Per Block: $(props.max_threads_per_block)",
        "Max Threads Dim: $(props.max_threads_dim)",
        "Max Grid Size: $(props.max_grid_size)",
        "Clock Rate: $(props.clock_rate)",
        "Total Const Mem: $(_format_bytes(props.total_const_mem))",
        "Version: $(VersionNumber(props.major, props.minor))",
        "Multi Processor Count: $(props.multi_processor_count)",
        "Can Map Host Memory: $(props.can_map_host_memory)",
        "Compute Mode: $(props.compute_mode)",
        "L2 Cache Size: $(props.l2_cache_size)",
        "Max Threads Per Multiprocessor: $(props.max_threads_per_multiprocessor)",
    )
    # Trailing "\n" keeps the output identical to the original triple-quoted string.
    return print(io, join(rows, "\n"), "\n")
end

src/xla/IFRT/Device.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,14 @@ function XLA.get_local_device_id(::Device)
3131
return error("Not implemented for ifrt devices")
3232
end
3333

34+
# Local hardware id (e.g. the CUDA device ordinal) of an IFRT device.
# BUG FIX: the parameter was declared unnamed (`::Device`) while the body
# referenced `device`, so both `GC.@preserve device` and `device.device`
# raised `UndefVarError` at runtime. Bind the argument as `device`.
function XLA.get_local_hardware_id(device::Device)
    GC.@preserve device begin
        return @ccall MLIR.API.mlir_c.ifrt_DeviceGetLocalHardwareId(
            device.device::Ptr{Cvoid}
        )::Cint
    end
end
41+
3442
function XLA.default_memory(device::Device)
3543
GC.@preserve device begin
3644
return Memory(

src/xla/PJRT/Device.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,14 @@ function XLA.get_local_device_id(device::Device)
3333
end
3434
end
3535

36+
# Local hardware id (e.g. the CUDA device ordinal) of a PJRT device,
# obtained from the underlying PjRtDevice handle.
function XLA.get_local_hardware_id(device::Device)
    handle = device.device
    # Keep `device` rooted so the wrapped handle stays valid across the ccall.
    GC.@preserve device begin
        local_id = @ccall MLIR.API.mlir_c.PjRtDeviceGetLocalHardwareId(
            handle::Ptr{Cvoid}
        )::Cint
        return local_id
    end
end
43+
3644
function XLA.is_addressable(device::Device)
3745
GC.@preserve device begin
3846
return @ccall MLIR.API.mlir_c.pjrt_device_is_addressable(

src/xla/Stats.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ struct JLAllocatorStats
1313
peak_pool_bytes::Int64
1414
end
1515

16-
# Human-readable byte count for stats display.
# Negative values are passed through as `nothing` — presumably a sentinel for
# "statistic unavailable" from the allocator; confirm against JLAllocatorStats
# producers. `Base.format_bytes` would otherwise throw on negatives.
function _format_bytes(x)
    x < 0 && return nothing
    return Base.format_bytes(x)
end
_format_bytes(x::Nothing) = x
1818

1919
"""

src/xla/XLA.jl

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -234,15 +234,6 @@ for runtime in (:PJRT, :IFRT)
234234
)
235235
state.clients["cuda"] = gpu
236236
state.default_client = gpu
237-
238-
# set values for cuda. This is being done here since we need cuda
239-
# to be initialized before we can use it. initializing the devices
240-
# implicitly initializes cuda.
241-
cc_major = @ccall MLIR.API.mlir_c.ReactantCudaDeviceGetComputeCapalilityMajor()::Int32
242-
cc_minor = @ccall MLIR.API.mlir_c.ReactantCudaDeviceGetComputeCapalilityMinor()::Int32
243-
Reactant.Compiler.cubinChip[] = "sm_$(cc_major)$(cc_minor)"
244-
245-
Reactant.Compiler.cuWarpSize[] = @ccall MLIR.API.mlir_c.ReactantCudaDeviceGetWarpSizeInThreads()::Int32
246237
catch e
247238
println(stdout, e)
248239
end

0 commit comments

Comments
 (0)