diff --git a/src/xla/Client.jl b/src/xla/Client.jl index ccf715c1ba..31925311b2 100644 --- a/src/xla/Client.jl +++ b/src/xla/Client.jl @@ -13,4 +13,13 @@ function get_device end function get_addressable_device end function platform_name end -default_device(client::AbstractClient) = first(addressable_devices(client)) +""" + DEFAULT_DEVICE :: Ref{Int} + +0-based index of default device to use, by default 0 (first available device). +""" +const DEFAULT_DEVICE = Ref{Int}(0) + +function default_device(client::AbstractClient) + return addressable_devices(client)[DEFAULT_DEVICE[] + 1] +end