diff --git a/src/accelerators/Accelerators.jl b/src/accelerators/Accelerators.jl
index 88476922f6..4d54915ce0 100644
--- a/src/accelerators/Accelerators.jl
+++ b/src/accelerators/Accelerators.jl
@@ -2,5 +2,6 @@ module Accelerators
 
 include("TPU.jl")
 include("Metal.jl")
+include("TT.jl")
 
 end
diff --git a/src/accelerators/TT.jl b/src/accelerators/TT.jl
new file mode 100644
index 0000000000..4d6cf5bce4
--- /dev/null
+++ b/src/accelerators/TT.jl
@@ -0,0 +1,122 @@
+module TT
+
+using Reactant: Reactant
+using Scratch: @get_scratch!
+using Downloads: Downloads
+using p7zip_jll: p7zip
+
+# Directory holding the TT PJRT plugin; `nothing` until `setup_tt_pjrt_plugin!` runs.
+const tt_pjrt_plugin_dir = Ref{Union{Nothing,String}}(nothing)
+# File name of the plugin shared library inside the plugin directory.
+const tt_pjrt_plugin_name = Ref{String}("pjrt_plugin_tt.so")
+# Directory where the Tenstorrent kernel driver exposes its device nodes.
+const tt_device_dir = "/dev/tenstorrent"
+
+# Set up the plugin at load time, only on Linux hosts where `has_tt()` reports a device.
+function __init__()
+    @static if Sys.islinux()
+        if !Reactant.precompiling() && has_tt()
+            setup_tt_pjrt_plugin!()
+        end
+    end
+end
+
+"""
+    has_tt()
+
+Return `true` when at least one entry exists under `/dev/tenstorrent`, i.e. when the
+Tenstorrent kernel driver has exposed a device, and `false` otherwise.
+
+Reporting `true` unconditionally would trigger the plugin download on every Linux host
+and make callers that branch on `has_tt()` select the TT backend without any hardware.
+"""
+has_tt() = isdir(tt_device_dir) && !isempty(readdir(tt_device_dir))
+
+"""
+    setup_tt_pjrt_plugin!()
+
+Choose the TT PJRT plugin directory and install the plugin there if it is missing.
+
+The path in the `TT_PJRT_PLUGIN_DIR` environment variable is used when it exists;
+otherwise the `pjrt_plugin_tt` scratch space is used.
+"""
+function setup_tt_pjrt_plugin!()
+    plugin_dir_from_env = get(ENV, "TT_PJRT_PLUGIN_DIR", nothing)
+    if plugin_dir_from_env !== nothing && ispath(plugin_dir_from_env)
+        tt_pjrt_plugin_dir[] = plugin_dir_from_env
+    else
+        if plugin_dir_from_env !== nothing
+            # Do not silently ignore a location the user asked for.
+            @warn "Ignoring `TT_PJRT_PLUGIN_DIR`: '$(plugin_dir_from_env)' does not exist"
+        end
+        tt_pjrt_plugin_dir[] = @get_scratch!("pjrt_plugin_tt")
+    end
+    download_tt_pjrt_plugin_if_needed(tt_pjrt_plugin_dir[])
+    return nothing
+end
+
+"""
+    get_tt_pjrt_plugin_dir()
+
+Return the TT PJRT plugin directory, or `nothing` if it has not been set.
+"""
+get_tt_pjrt_plugin_dir() = tt_pjrt_plugin_dir[]
+
+"""
+    get_tt_pjrt_plugin_path()
+
+Return the full path of the TT PJRT plugin library inside the plugin directory.
+Throws an error if the plugin directory has not been set.
+"""
+function get_tt_pjrt_plugin_path()
+    dir = get_tt_pjrt_plugin_dir()
+    # Fail with a clear message instead of a `MethodError` from `joinpath(nothing, ...)`.
+    dir === nothing && error("tt_pjrt_plugin_dir is not set!")
+    return joinpath(dir, tt_pjrt_plugin_name[])
+end
+
+"""
+    download_tt_pjrt_plugin_if_needed(dir=nothing)
+
+Install the TT PJRT plugin into `dir` unless the plugin library is already there.
+
+`dir` defaults to `get_tt_pjrt_plugin_dir()`. The wheel is downloaded and unpacked in a
+temporary directory, and its `pjrt_plugin_tt` directory is then moved to `dir`.
+Throws if no directory is set, the architecture is unsupported, or the plugin library
+is still missing after installation.
+"""
+function download_tt_pjrt_plugin_if_needed(dir=nothing)
+    dir === nothing && (dir = get_tt_pjrt_plugin_dir())
+    # Explicit check rather than `@assert`, which is not guaranteed to be evaluated.
+    dir === nothing && error("tt_pjrt_plugin_dir is not set!")
+
+    tt_pjrt_plugin_path = joinpath(dir, tt_pjrt_plugin_name[])
+    if isfile(tt_pjrt_plugin_path)
+        @debug "TT PJRT plugin already found in '$(tt_pjrt_plugin_path)', nothing to do"
+        return nothing
+    end
+
+    @debug "Will install the TT PJRT plugin to '$(tt_pjrt_plugin_path)'"
+    # Reject unsupported hosts before creating temporary files or downloading anything.
+    wheel_url = if Sys.ARCH === :x86_64
+        "https://pypi.eng.aws.tenstorrent.com/pjrt-plugin-tt/pjrt_plugin_tt-0.6.0.dev20251113-cp311-cp311-linux_x86_64.whl"
+    else
+        error("Unsupported architecture: $(Sys.ARCH)")
+    end
+    mktempdir() do tmp_dir
+        zip_file_path = joinpath(tmp_dir, "pjrt-plugin-tt.zip")
+        @debug "Downloading TT PJRT plugin from '$(wheel_url)'"
+        Downloads.download(wheel_url, zip_file_path)
+        # Unpack the wheel as a zip archive into the temporary directory.
+        run(pipeline(`$(p7zip()) x -tzip -o$(tmp_dir) -- $(zip_file_path)`, devnull))
+        data_dir = only(filter!(endswith(".data"), readdir(tmp_dir; join=true)))
+        # We need to move the entire `pjrt_plugin_tt` directory to the destination.
+        # `force=true` replaces whatever currently exists at `dir`.
+        mv(joinpath(data_dir, "purelib", "pjrt_plugin_tt"), dir; force=true)
+    end
+    isfile(tt_pjrt_plugin_path) ||
+        error("TT PJRT plugin not found at '$(tt_pjrt_plugin_path)' after installation")
+    return nothing
+end
+
+end # module TT
diff --git a/src/xla/IFRT/Client.jl b/src/xla/IFRT/Client.jl
index e13a308893..9bf7402af2 100644
--- a/src/xla/IFRT/Client.jl
+++ b/src/xla/IFRT/Client.jl
@@ -115,12 +115,14 @@ const cpu_client_count = Ref(0)
 const cuda_client_count = Ref(0)
 const tpu_client_count = Ref(0)
 const metal_client_count = Ref(0)
+const tt_client_count = Ref(0)
 
 for (backend, counter) in (
     (:CPUClient, :cpu_client_count),
     (:CUDAClient, :cuda_client_count),
     (:TPUClient, :tpu_client_count),
     (:MetalClient, :metal_client_count),
+    (:TTClient, :tt_client_count),
 )
     main_fn = Symbol(:MakeIFRTPJRT, backend)
     @eval function $(backend)(args...; checkcount::Bool=true, kwargs...)
@@ -219,6 +221,25 @@ function MakeIFRTPJRTMetalClient(;
     )
 end
 
+"""
+    MakeIFRTPJRTTTClient(; tt_pjrt_plugin_path, node_id=0, num_nodes=1,
+                         distributed_runtime_client=nothing)
+
+Create a client for the TT backend from the PJRT plugin at `tt_pjrt_plugin_path`.
+
+This is a thin wrapper: the plugin path, `"tt"` and `"TT"` are passed positionally to
+`MakeIFRTPJRTClientViaPluginAPI`, and the remaining keywords are forwarded unchanged.
+"""
+function MakeIFRTPJRTTTClient(;
+    tt_pjrt_plugin_path::String,
+    node_id::Integer=0,
+    num_nodes::Integer=1,
+    distributed_runtime_client::Union{Nothing,XLA.DistributedRuntimeClient}=nothing,
+)
+    forwarded = (; node_id, num_nodes, distributed_runtime_client)
+    return MakeIFRTPJRTClientViaPluginAPI(tt_pjrt_plugin_path, "tt", "TT"; forwarded...)
+end
+
 function MakeIFRTPJRTClientViaPluginAPI(
     library_path::String,
     device_type::String,
diff --git a/src/xla/PJRT/Client.jl b/src/xla/PJRT/Client.jl
index c45aeac1a1..b4f3d7cb3c 100644
--- a/src/xla/PJRT/Client.jl
+++ b/src/xla/PJRT/Client.jl
@@ -110,12 +110,14 @@ const cpu_client_count = Ref(0)
 const cuda_client_count = Ref(0)
 const tpu_client_count = Ref(0)
 const metal_client_count = Ref(0)
+const tt_client_count = Ref(0)
 
 for (backend, counter) in (
     (:CPUClient, :cpu_client_count),
     (:CUDAClient, :cuda_client_count),
     (:TPUClient, :tpu_client_count),
     (:MetalClient, :metal_client_count),
+    (:TTClient, :tt_client_count),
 )
     main_fn = Symbol(:Make, backend)
     @eval function $(backend)(args...; checkcount::Bool=true, kwargs...)
@@ -207,6 +209,33 @@ function MakeMetalClient(;
     return MakeClientUsingPluginAPI(metal_pjrt_plugin_path, "metal", "METAL")
 end
 
+"""
+    MakeTTClient(; tt_pjrt_plugin_path, node_id=0, num_nodes=1,
+                 distributed_runtime_client=nothing)
+
+Create a client for the TT backend from the PJRT plugin at `tt_pjrt_plugin_path`.
+
+The distributed keywords are accepted but not supported: any value other than the
+default throws an `ArgumentError`.
+"""
+function MakeTTClient(;
+    tt_pjrt_plugin_path::String,
+    node_id::Integer=0,
+    num_nodes::Integer=1,
+    distributed_runtime_client::Union{Nothing,XLA.DistributedRuntimeClient}=nothing,
+)
+    # Explicit checks instead of `@assert`: argument validation must always run.
+    node_id == 0 ||
+        throw(ArgumentError("`PJRT.MakeTTClient` does not support node_id"))
+    num_nodes == 1 ||
+        throw(ArgumentError("`PJRT.MakeTTClient` does not support num_nodes > 1"))
+    distributed_runtime_client === nothing || throw(
+        ArgumentError("`PJRT.MakeTTClient` does not support distributed_runtime_client")
+    )
+
+    return MakeClientUsingPluginAPI(tt_pjrt_plugin_path, "tt", "TT")
+end
+
 function MakeClientUsingPluginAPI(
     library_path::String, device_type::String, client_name::String=uppercase(device_type)
 )
diff --git a/src/xla/XLA.jl b/src/xla/XLA.jl
index bc01ab11bb..5fa34457b1 100644
--- a/src/xla/XLA.jl
+++ b/src/xla/XLA.jl
@@ -226,6 +226,40 @@ for runtime in (:PJRT, :IFRT)
                     catch e
                         println(stdout, e)
                     end
+                elseif Accelerators.TT.has_tt()
+                    @debug "TT accelerator detected, setting it up"
+                    try
+                        if was_initialized && haskey(state.clients, "tt")
+                            free_client(state.clients["tt"])
+                            $(runtime).tt_client_count[] -= 1
+                        end
+                        # Resolve the plugin path once: it locates `tt-metal` and creates the client.
+                        tt_pjrt_plugin_path = Accelerators.TT.get_tt_pjrt_plugin_path()
+                        # The env var `TT_METAL_RUNTIME_ROOT` must be set before creating the client.
+                        tt_metal_runtime_root = get(ENV, "TT_METAL_RUNTIME_ROOT", nothing)
+                        if isnothing(tt_metal_runtime_root)
+                            # Fall back to the `tt-metal` directory located next to the plugin.
+                            tt_metal_path_in_wheel = joinpath(
+                                dirname(tt_pjrt_plugin_path), "tt-metal"
+                            )
+                            if ispath(tt_metal_path_in_wheel)
+                                @debug "Setting environment variable 'TT_METAL_RUNTIME_ROOT' to '$(tt_metal_path_in_wheel)'"
+                                ENV["TT_METAL_RUNTIME_ROOT"] = tt_metal_path_in_wheel
+                            else
+                                error(
+                                    "`TT_METAL_RUNTIME_ROOT` environment variable not set and we could not automatically determine it",
+                                )
+                            end
+                        else
+                            @debug "Environment variable 'TT_METAL_RUNTIME_ROOT' already set to '$(tt_metal_runtime_root)'"
+                        end
+
+                        tt = $(runtime).TTClient(; tt_pjrt_plugin_path, common_kwargs...)
+                        state.clients["tt"] = tt
+                        state.default_client = tt
+                    catch e
+                        println(stdout, e)
+                    end
                 elseif Reactant_jll.host_platform.tags["gpu"] != "none"
                     try
                         if was_initialized && haskey(state.clients, "cuda")