From 2c169ef6b31009ff3ef3ce648415fc166b60aa06 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mos=C3=A8=20Giordano?= Date: Thu, 13 Nov 2025 18:27:21 +0000 Subject: [PATCH 1/3] Use `REACTANT_DEFAULT_DEVICE` to set default device ID with env var --- src/xla/Client.jl | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/xla/Client.jl b/src/xla/Client.jl index ccf715c1ba..ce517b6231 100644 --- a/src/xla/Client.jl +++ b/src/xla/Client.jl @@ -13,4 +13,8 @@ function get_device end function get_addressable_device end function platform_name end -default_device(client::AbstractClient) = first(addressable_devices(client)) +function default_device(client::AbstractClient) + return addressable_devices(client)[something( + tryparse(Int, get(ENV, "REACTANT_DEFAULT_DEVICE", "1")), 1 + )] +end From 3f9be82a561352eb829caeb4ef841bc442668613 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mos=C3=A8=20Giordano?= Date: Thu, 13 Nov 2025 18:47:06 +0000 Subject: [PATCH 2/3] Interpret `REACTANT_DEFAULT_DEVICE` as 0-based --- src/xla/Client.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/xla/Client.jl b/src/xla/Client.jl index ce517b6231..30233154ec 100644 --- a/src/xla/Client.jl +++ b/src/xla/Client.jl @@ -15,6 +15,8 @@ function platform_name end function default_device(client::AbstractClient) return addressable_devices(client)[something( - tryparse(Int, get(ENV, "REACTANT_DEFAULT_DEVICE", "1")), 1 + # `REACTANT_DEFAULT_DEVICE` is interpreted as 0-based. + tryparse(Int, get(ENV, "REACTANT_DEFAULT_DEVICE", "0")) + 1, + 1, )] end From af67d52a945dae1a2bfdcdc62b182c359f827786 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mos=C3=A8=20Giordano?= Date: Thu, 13 Nov 2025 19:26:11 +0000 Subject: [PATCH 3/3] Use global `const` `Reactant.XLA.DEFAULT_DEVICE` instead of env var --- src/xla/Client.jl | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/src/xla/Client.jl b/src/xla/Client.jl index 30233154ec..31925311b2 100644 --- a/src/xla/Client.jl +++ b/src/xla/Client.jl @@ -13,10 +13,13 @@ function get_device end function get_addressable_device end function platform_name end +""" + DEFAULT_DEVICE :: Ref{Int} + +0-based index of default device to use, by default 0 (first available device). +""" +const DEFAULT_DEVICE = Ref{Int}(0) + function default_device(client::AbstractClient) - return addressable_devices(client)[something( - # `REACTANT_DEFAULT_DEVICE` is interpreted as 0-based. - tryparse(Int, get(ENV, "REACTANT_DEFAULT_DEVICE", "0")) + 1, - 1, - )] + return addressable_devices(client)[DEFAULT_DEVICE[] + 1] end