@@ -1395,7 +1395,8 @@ function __get_compile_options_and_kwargs(;
13951395end
13961396
13971397function compile_mlir (f, args; client= nothing , kwargs... )
1398- backend = XLA. platform_name (client != = nothing ? client : XLA. default_backend ())
1398+ client = client != = nothing ? client : XLA. default_backend ()
1399+ backend = XLA. platform_name (client)
13991400
14001401 if backend == " CUDA"
14011402 backend = " GPU"
@@ -1414,6 +1415,7 @@ function compile_mlir(f, args; client=nothing, kwargs...)
14141415 compile_options;
14151416 backend,
14161417 runtime= XLA. runtime (client),
1418+ client,
14171419 kwargs... ,
14181420 )
14191421
@@ -1430,11 +1432,9 @@ end
14301432
14311433const PartitionKA = Ref {Bool} (true )
14321434
1433- const cubinChip = Ref {String} (" sm_60" )
1434- const cubinFormat = Ref {String} (" bin" )
14351435const cuindexBitWidth = Ref {Int} (32 )
1436+ const cubinFormat = Ref {String} (" bin" )
14361437const cuOptLevel = Ref {Int} (2 )
1437- const cuWarpSize = Ref {Int} (32 )
14381438
14391439# Whatever the relevant highest version from our LLVM is within NVPTX.td
14401440# Or more specifically looking at clang/lib/Driver/ToolChains/Cuda.cpp:684
@@ -1580,8 +1580,11 @@ function compile_mlir!(
15801580 backend= " gpu" ,
15811581 runtime:: Union{Val{:PJRT},Val{:IFRT}} ,
15821582 legalize_stablehlo_to_mhlo:: Bool = false ,
1583+ client= nothing ,
15831584 kwargs... ,
15841585)
1586+ client = client != = nothing ? client : XLA. default_backend ()
1587+
15851588 # Explicitly don't use block! to avoid creating a closure, which creates
15861589 # both compile-time and relocatability issues
15871590
@@ -1655,25 +1658,27 @@ function compile_mlir!(
16551658 else
16561659 jit = " lower-jit{openmp=$(OpenMP[]) backend=cpu},symbol-dce"
16571660 end
1658- elseif DEBUG_KERNEL[]
1659- curesulthandler = dlsym (
1660- Reactant_jll. libReactantExtra_handle, " ReactantHandleCuResult"
1661- )
1662- @assert curesulthandler != = nothing
1663- curesulthandler = Base. reinterpret (UInt, curesulthandler)
1661+ else
16641662 kern = if is_raising
16651663 " lower-kernel{backend=cpu},symbol-dce,canonicalize"
16661664 else
16671665 " lower-kernel,canonicalize"
16681666 end
1669- jit = " lower-jit{debug=true cuResultHandlerPtr=$curesulthandler cuOptLevel=$(cuOptLevel[]) cubinFormat=$(cubinFormat[]) indexBitWidth=$(cuindexBitWidth[]) cubinChip=$(cubinChip[]) cubinFeatures=$(cubinFeatures ()) run_init=true toolkitPath=$toolkit },symbol-dce"
1670- else
1671- kern = if is_raising
1672- " lower-kernel{backend=cpu},symbol-dce,canonicalize"
1667+
1668+ device_properties = XLA. device_properties (XLA. default_device (client))
1669+ cubinChip = " sm_$(device_properties. major)$(device_properties. minor) "
1670+
1671+ if DEBUG_KERNEL[]
1672+ curesulthandler = dlsym (
1673+ Reactant_jll. libReactantExtra_handle, " ReactantHandleCuResult"
1674+ )
1675+ @assert curesulthandler != = nothing
1676+ curesulthandler = Base. reinterpret (UInt, curesulthandler)
1677+ extra_lowerjit_options = " debug=true cuResultHandlerPtr=$curesulthandler "
16731678 else
1674- " lower-kernel,canonicalize "
1679+ extra_lowerjit_options = " "
16751680 end
1676- jit = " lower-jit{cuOptLevel=$(cuOptLevel[]) indexBitWidth =$(cuindexBitWidth []) cubinFormat =$(cubinFormat []) cubinChip=$(cubinChip[] ) cubinFeatures=$(cubinFeatures ()) run_init=true toolkitPath=$toolkit },symbol-dce"
1681+ jit = " lower-jit{$(extra_lowerjit_options) cuOptLevel=$(cuOptLevel[]) cubinFormat =$(cubinFormat []) indexBitWidth =$(cuindexBitWidth []) cubinChip=$(cubinChip) cubinFeatures=$(cubinFeatures ()) run_init=true toolkitPath=$toolkit },symbol-dce"
16771682 end
16781683
16791684 recognize_comms = true
@@ -3477,7 +3482,8 @@ function compile_xla(
34773482 context_gc_vector[ctx] = Vector {Union{TracedRArray,TracedRNumber}} (undef, 0 )
34783483 @ccall MLIR. API. mlir_c. RegisterDialects (ctx:: MLIR.API.MlirContext ):: Cvoid
34793484
3480- backend = XLA. platform_name (client != = nothing ? client : XLA. default_backend ())
3485+ client = client != = nothing ? client : XLA. default_backend ()
3486+ backend = XLA. platform_name (client)
34813487
34823488 if backend == " CUDA"
34833489 backend = " GPU"
@@ -3498,6 +3504,7 @@ function compile_xla(
34983504 compile_options;
34993505 backend,
35003506 runtime= XLA. runtime (client),
3507+ client,
35013508 kwargs... ,
35023509 )
35033510
0 commit comments