From 596d4faf8e799b10c7e45825be5168fa7b333756 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mos=C3=A8=20Giordano?= Date: Wed, 12 Nov 2025 18:46:22 +0000 Subject: [PATCH 1/3] Create TT client via plugin for TensTorrent devices --- src/accelerators/Accelerators.jl | 1 + src/accelerators/TT.jl | 62 ++++++++++++++++++++++++++++++++ src/xla/IFRT/Client.jl | 13 +++++++ src/xla/PJRT/Client.jl | 16 +++++++++ src/xla/XLA.jl | 30 ++++++++++++++++ 5 files changed, 122 insertions(+) create mode 100644 src/accelerators/TT.jl diff --git a/src/accelerators/Accelerators.jl b/src/accelerators/Accelerators.jl index 88476922f6..4d54915ce0 100644 --- a/src/accelerators/Accelerators.jl +++ b/src/accelerators/Accelerators.jl @@ -2,5 +2,6 @@ module Accelerators include("TPU.jl") include("Metal.jl") +include("TT.jl") end diff --git a/src/accelerators/TT.jl b/src/accelerators/TT.jl new file mode 100644 index 0000000000..59636977b0 --- /dev/null +++ b/src/accelerators/TT.jl @@ -0,0 +1,62 @@ +module TT + +using Reactant: Reactant +using Scratch: @get_scratch! +using Downloads + +const tt_pjrt_plugin_dir = Ref{Union{Nothing,String}}(nothing) + +function __init__() + @static if Sys.islinux() + Reactant.precompiling() || setup_tt_pjrt_plugin!() + end +end + +has_tt() = true + +function setup_tt_pjrt_plugin!() + path_from_env = get(ENV, "TT_LIBRARY_PATH", nothing) + if path_from_env !== nothing && ispath(path_from_env) + tt_pjrt_plugin_dir[] = path_from_env + else + tt_pjrt_plugin_dir[] = @get_scratch!("pjrt_tt_plugin") + end + # download_tt_pjrt_plugin_if_needed(tt_pjrt_plugin_dir[]) + return nothing +end + +get_tt_pjrt_plugin_dir() = tt_pjrt_plugin_dir[] + +function get_tt_pjrt_plugin_path() + return joinpath(get_tt_pjrt_plugin_dir(), "pjrt_plugin_tt.so") +end + +# function download_tt_pjrt_plugin_if_needed(path=nothing) +# path === nothing && (path = get_tt_pjrt_plugin_dir()) +# @assert path !== nothing "tt_pjrt_plugin_dir is not set!" 
+ +# tt_pjrt_plugin_path = joinpath(path, "pjrt_plugin_tt_14.dylib") +# if !isfile(tt_pjrt_plugin_path) +# zip_file_path = joinpath(path, "pjrt-plugin-tt.zip") +# tmp_dir = joinpath(path, "tmp") +# Downloads.download( +# if Sys.ARCH === :aarch64 +# "https://files.pythonhosted.org/packages/09/dc/6d8fbfc29d902251cf333414cf7dcfaf4b252a9920c881354584ed36270d/jax_tt-0.1.1-py3-none-macosx_13_0_arm64.whl" +# elseif Sys.ARCH === :x86_64 +# "https://files.pythonhosted.org/packages/87/ec/9bb7f7f0ffd06c3fb89813126b2f698636ac7a4263ed7bdd1ff7d7c94f8f/jax_tt-0.1.1-py3-none-macosx_10_14_x86_64.whl" +# else +# error("Unsupported architecture: $(Sys.ARCH)") +# end, +# zip_file_path, +# ) +# run(`unzip -qq $(zip_file_path) -d $(tmp_dir)`) +# mv( +# joinpath(tmp_dir, "jax_plugins", "tt_plugin", "pjrt_plugin_tt_14.dylib"), +# tt_pjrt_plugin_path, +# ) +# rm(tmp_dir; recursive=true) +# rm(zip_file_path; recursive=true) +# end +# end + +end # module TT diff --git a/src/xla/IFRT/Client.jl b/src/xla/IFRT/Client.jl index e13a308893..9bf7402af2 100644 --- a/src/xla/IFRT/Client.jl +++ b/src/xla/IFRT/Client.jl @@ -115,12 +115,14 @@ const cpu_client_count = Ref(0) const cuda_client_count = Ref(0) const tpu_client_count = Ref(0) const metal_client_count = Ref(0) +const tt_client_count = Ref(0) for (backend, counter) in ( (:CPUClient, :cpu_client_count), (:CUDAClient, :cuda_client_count), (:TPUClient, :tpu_client_count), (:MetalClient, :metal_client_count), + (:TTClient, :tt_client_count), ) main_fn = Symbol(:MakeIFRTPJRT, backend) @eval function $(backend)(args...; checkcount::Bool=true, kwargs...) 
@@ -219,6 +221,17 @@ function MakeIFRTPJRTMetalClient(; ) end +function MakeIFRTPJRTTTClient(; + tt_pjrt_plugin_path::String, + node_id::Integer=0, + num_nodes::Integer=1, + distributed_runtime_client::Union{Nothing,XLA.DistributedRuntimeClient}=nothing, +) + return MakeIFRTPJRTClientViaPluginAPI( + tt_pjrt_plugin_path, "tt", "TT"; node_id, num_nodes, distributed_runtime_client + ) +end + function MakeIFRTPJRTClientViaPluginAPI( library_path::String, device_type::String, diff --git a/src/xla/PJRT/Client.jl b/src/xla/PJRT/Client.jl index c45aeac1a1..b4f3d7cb3c 100644 --- a/src/xla/PJRT/Client.jl +++ b/src/xla/PJRT/Client.jl @@ -110,12 +110,14 @@ const cpu_client_count = Ref(0) const cuda_client_count = Ref(0) const tpu_client_count = Ref(0) const metal_client_count = Ref(0) +const tt_client_count = Ref(0) for (backend, counter) in ( (:CPUClient, :cpu_client_count), (:CUDAClient, :cuda_client_count), (:TPUClient, :tpu_client_count), (:MetalClient, :metal_client_count), + (:TTClient, :tt_client_count), ) main_fn = Symbol(:Make, backend) @eval function $(backend)(args...; checkcount::Bool=true, kwargs...) 
@@ -207,6 +209,20 @@ function MakeMetalClient(; return MakeClientUsingPluginAPI(metal_pjrt_plugin_path, "metal", "METAL") end +function MakeTTClient(; + tt_pjrt_plugin_path::String, + node_id::Integer=0, + num_nodes::Integer=1, + distributed_runtime_client::Union{Nothing,XLA.DistributedRuntimeClient}=nothing, +) + @assert node_id == 0 "`PJRT.MakeTTClient` does not support node_id" + @assert num_nodes == 1 "`PJRT.MakeTTClient` does not support num_nodes > 1" + @assert distributed_runtime_client === nothing "`PJRT.MakeTTClient` does not support \ + distributed_runtime_client" + + return MakeClientUsingPluginAPI(tt_pjrt_plugin_path, "tt", "TT") +end + function MakeClientUsingPluginAPI( library_path::String, device_type::String, client_name::String=uppercase(device_type) ) diff --git a/src/xla/XLA.jl b/src/xla/XLA.jl index bc01ab11bb..f2aa405a0d 100644 --- a/src/xla/XLA.jl +++ b/src/xla/XLA.jl @@ -226,6 +226,36 @@ for runtime in (:PJRT, :IFRT) catch e println(stdout, e) end + elseif Accelerators.TT.has_tt() + try + if was_initialized && haskey(state.clients, "tt") + XLA.free_client(state.clients["tt"]) + XLA.$(runtime).tt_client_count[] -= 1 + end + # The env var `TT_METAL_RUNTIME_ROOT` must be set before creating the client. 
+ if isnothing(get(ENV, "TT_METAL_RUNTIME_ROOT", nothing)) + tt_metal_path_in_wheel = joinpath( + dirname(Accelerators.TT.get_tt_pjrt_plugin_path()), + "tt-metal", + ) + if ispath(tt_metal_path_in_wheel) + ENV["TT_METAL_RUNTIME_ROOT"] = tt_metal_path_in_wheel + else + error( + "`TT_METAL_RUNTIME_ROOT` environment variable not set and we could not automatically determine it", + ) + end + end + + tt = $(runtime).TTClient(; + tt_pjrt_plugin_path=Accelerators.TT.get_tt_pjrt_plugin_path(), + common_kwargs..., + ) + state.clients["tt"] = tt + state.default_client = tt + catch e + println(stdout, e) + end elseif Reactant_jll.host_platform.tags["gpu"] != "none" try if was_initialized && haskey(state.clients, "cuda") From 72852fcbb961203465e5d0a5a5aa3e6da6e4a529 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mos=C3=A8=20Giordano?= Date: Thu, 13 Nov 2025 12:50:49 +0000 Subject: [PATCH 2/3] Automatically download wheel for TT PJRT plugin --- src/accelerators/TT.jl | 73 ++++++++++++++++++++++-------------------- src/xla/XLA.jl | 7 +++- 2 files changed, 44 insertions(+), 36 deletions(-) diff --git a/src/accelerators/TT.jl b/src/accelerators/TT.jl index 59636977b0..4d6cf5bce4 100644 --- a/src/accelerators/TT.jl +++ b/src/accelerators/TT.jl @@ -2,61 +2,64 @@ module TT using Reactant: Reactant using Scratch: @get_scratch! 
-using Downloads +using Downloads: Downloads +using p7zip_jll: p7zip const tt_pjrt_plugin_dir = Ref{Union{Nothing,String}}(nothing) +const tt_pjrt_plugin_name = Ref{String}("pjrt_plugin_tt.so") function __init__() @static if Sys.islinux() - Reactant.precompiling() || setup_tt_pjrt_plugin!() + if !Reactant.precompiling() && has_tt() + setup_tt_pjrt_plugin!() + end end end has_tt() = true function setup_tt_pjrt_plugin!() - path_from_env = get(ENV, "TT_LIBRARY_PATH", nothing) - if path_from_env !== nothing && ispath(path_from_env) - tt_pjrt_plugin_dir[] = path_from_env + plugin_dir_from_env = get(ENV, "TT_PJRT_PLUGIN_DIR", nothing) + if plugin_dir_from_env !== nothing && ispath(plugin_dir_from_env) + tt_pjrt_plugin_dir[] = plugin_dir_from_env else - tt_pjrt_plugin_dir[] = @get_scratch!("pjrt_tt_plugin") + tt_pjrt_plugin_dir[] = @get_scratch!("pjrt_plugin_tt") end - # download_tt_pjrt_plugin_if_needed(tt_pjrt_plugin_dir[]) + download_tt_pjrt_plugin_if_needed(tt_pjrt_plugin_dir[]) return nothing end get_tt_pjrt_plugin_dir() = tt_pjrt_plugin_dir[] function get_tt_pjrt_plugin_path() - return joinpath(get_tt_pjrt_plugin_dir(), "pjrt_plugin_tt.so") + return joinpath(get_tt_pjrt_plugin_dir(), tt_pjrt_plugin_name[]) end -# function download_tt_pjrt_plugin_if_needed(path=nothing) -# path === nothing && (path = get_tt_pjrt_plugin_dir()) -# @assert path !== nothing "tt_pjrt_plugin_dir is not set!" 
- -# tt_pjrt_plugin_path = joinpath(path, "pjrt_plugin_tt_14.dylib") -# if !isfile(tt_pjrt_plugin_path) -# zip_file_path = joinpath(path, "pjrt-plugin-tt.zip") -# tmp_dir = joinpath(path, "tmp") -# Downloads.download( -# if Sys.ARCH === :aarch64 -# "https://files.pythonhosted.org/packages/09/dc/6d8fbfc29d902251cf333414cf7dcfaf4b252a9920c881354584ed36270d/jax_tt-0.1.1-py3-none-macosx_13_0_arm64.whl" -# elseif Sys.ARCH === :x86_64 -# "https://files.pythonhosted.org/packages/87/ec/9bb7f7f0ffd06c3fb89813126b2f698636ac7a4263ed7bdd1ff7d7c94f8f/jax_tt-0.1.1-py3-none-macosx_10_14_x86_64.whl" -# else -# error("Unsupported architecture: $(Sys.ARCH)") -# end, -# zip_file_path, -# ) -# run(`unzip -qq $(zip_file_path) -d $(tmp_dir)`) -# mv( -# joinpath(tmp_dir, "jax_plugins", "tt_plugin", "pjrt_plugin_tt_14.dylib"), -# tt_pjrt_plugin_path, -# ) -# rm(tmp_dir; recursive=true) -# rm(zip_file_path; recursive=true) -# end -# end +function download_tt_pjrt_plugin_if_needed(dir=nothing) + dir === nothing && (dir = get_tt_pjrt_plugin_dir()) + @assert dir !== nothing "tt_pjrt_plugin_dir is not set!" + + tt_pjrt_plugin_path = joinpath(dir, tt_pjrt_plugin_name[]) + if isfile(tt_pjrt_plugin_path) + @debug "TT PJRT plugin already found in '$(tt_pjrt_plugin_path)', nothing to do" + else + @debug "Will install the TT PJRT plugin to '$(tt_pjrt_plugin_path)'" + mktempdir() do tmp_dir + zip_file_path = joinpath(tmp_dir, "pjrt-plugin-tt.zip") + wheel_url = if Sys.ARCH === :x86_64 + "https://pypi.eng.aws.tenstorrent.com/pjrt-plugin-tt/pjrt_plugin_tt-0.6.0.dev20251113-cp311-cp311-linux_x86_64.whl" + else + error("Unsupported architecture: $(Sys.ARCH)") + end + @debug "Downloading TT PJRT plugin from '$(wheel_url)'" + Downloads.download(wheel_url, zip_file_path) + run(pipeline(`$(p7zip()) x -tzip -o$(tmp_dir) -- $(zip_file_path)`, devnull)) + data_dir = only(filter!(endswith(".data"), readdir(tmp_dir; join=true))) + # We need to move the entire `pjrt_plugin_tt` directory to the destination. 
+ mv(joinpath(data_dir, "purelib", "pjrt_plugin_tt"), dir; force=true) + end + @assert isfile(tt_pjrt_plugin_path) + end +end end # module TT diff --git a/src/xla/XLA.jl b/src/xla/XLA.jl index f2aa405a0d..978b31623f 100644 --- a/src/xla/XLA.jl +++ b/src/xla/XLA.jl @@ -227,24 +227,29 @@ for runtime in (:PJRT, :IFRT) println(stdout, e) end elseif Accelerators.TT.has_tt() + @debug "TT accelerator detected, setting it up" try if was_initialized && haskey(state.clients, "tt") XLA.free_client(state.clients["tt"]) XLA.$(runtime).tt_client_count[] -= 1 end # The env var `TT_METAL_RUNTIME_ROOT` must be set before creating the client. - if isnothing(get(ENV, "TT_METAL_RUNTIME_ROOT", nothing)) + tt_metal_runtime_root = get(ENV, "TT_METAL_RUNTIME_ROOT", nothing) + if isnothing(tt_metal_runtime_root) tt_metal_path_in_wheel = joinpath( dirname(Accelerators.TT.get_tt_pjrt_plugin_path()), "tt-metal", ) if ispath(tt_metal_path_in_wheel) + @debug "Setting environment variable 'TT_METAL_RUNTIME_ROOT' to '$(tt_metal_path_in_wheel)'" ENV["TT_METAL_RUNTIME_ROOT"] = tt_metal_path_in_wheel else error( "`TT_METAL_RUNTIME_ROOT` environment variable not set and we could not automatically determine it", ) end + else + @debug "Environment variable 'TT_METAL_RUNTIME_ROOT' already set to '$(tt_metal_runtime_root)'" end tt = $(runtime).TTClient(; From c819a82f2cff972024559a6a63336edd8592f759 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mos=C3=A8=20Giordano?= Date: Thu, 13 Nov 2025 14:00:35 +0000 Subject: [PATCH 3/3] Remove self-qualified accesses --- src/xla/XLA.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/xla/XLA.jl b/src/xla/XLA.jl index 978b31623f..5fa34457b1 100644 --- a/src/xla/XLA.jl +++ b/src/xla/XLA.jl @@ -230,8 +230,8 @@ for runtime in (:PJRT, :IFRT) @debug "TT accelerator detected, setting it up" try if was_initialized && haskey(state.clients, "tt") - XLA.free_client(state.clients["tt"]) - XLA.$(runtime).tt_client_count[] -= 1 + 
free_client(state.clients["tt"]) + $(runtime).tt_client_count[] -= 1 end # The env var `TT_METAL_RUNTIME_ROOT` must be set before creating the client. tt_metal_runtime_root = get(ENV, "TT_METAL_RUNTIME_ROOT", nothing)