Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/accelerators/Accelerators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@ module Accelerators

include("TPU.jl")
include("Metal.jl")
include("TT.jl")

end
65 changes: 65 additions & 0 deletions src/accelerators/TT.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
module TT

using Reactant: Reactant
using Scratch: @get_scratch!
using Downloads: Downloads
using p7zip_jll: p7zip

# Directory holding the TT PJRT plugin. Starts as `nothing` and is assigned by
# `setup_tt_pjrt_plugin!` (either from `TT_PJRT_PLUGIN_DIR` or a scratch space).
const tt_pjrt_plugin_dir = Ref{Union{Nothing,String}}(nothing)
# File name of the plugin library that is looked up inside `tt_pjrt_plugin_dir[]`.
const tt_pjrt_plugin_name = Ref{String}("pjrt_plugin_tt.so")

# Module initializer. On Linux only, and only outside of precompilation, set up
# the TT PJRT plugin when `has_tt()` reports that a TT accelerator is available.
function __init__()
    @static if Sys.islinux()
        # Skip all setup while Reactant is being precompiled.
        Reactant.precompiling() && return nothing
        # Nothing to prepare when no TT accelerator is reported.
        has_tt() || return nothing
        setup_tt_pjrt_plugin!()
    end
    return nothing
end

"""
    has_tt()

Return whether a TT accelerator should be considered available.

This is currently a placeholder: it performs no device detection and always
returns `true`.
"""
function has_tt()
    return true
end
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I need to figure out how to detect the devices automatically, but apart from that this should be good from my side as an experimental plugin. I'm not sure about using the "TT" name everywhere, though: it's a bit too short and obscure, but that's what all the Tenstorrent tools are called.


"""
    setup_tt_pjrt_plugin!()

Select the directory used for the TT PJRT plugin, store it in
`tt_pjrt_plugin_dir[]`, and download the plugin into it if it is missing.

The directory named by the `TT_PJRT_PLUGIN_DIR` environment variable is used
when that variable is set and points to an existing path; otherwise the
`"pjrt_plugin_tt"` scratch space is used. Returns `nothing`.
"""
function setup_tt_pjrt_plugin!()
    env_dir = get(ENV, "TT_PJRT_PLUGIN_DIR", nothing)
    use_env_dir = env_dir !== nothing && ispath(env_dir)
    # The scratch space is only requested when the environment override is unusable.
    tt_pjrt_plugin_dir[] = use_env_dir ? env_dir : @get_scratch!("pjrt_plugin_tt")
    download_tt_pjrt_plugin_if_needed(tt_pjrt_plugin_dir[])
    return nothing
end

"""
    get_tt_pjrt_plugin_dir()

Return the currently configured TT PJRT plugin directory, i.e. the value of
`tt_pjrt_plugin_dir[]`. This is `nothing` until the directory has been set.
"""
function get_tt_pjrt_plugin_dir()
    return tt_pjrt_plugin_dir[]
end

"""
    get_tt_pjrt_plugin_path()

Return the path of the TT PJRT plugin library: the configured plugin directory
joined with `tt_pjrt_plugin_name[]`.
"""
function get_tt_pjrt_plugin_path()
    plugin_dir = get_tt_pjrt_plugin_dir()
    plugin_file = tt_pjrt_plugin_name[]
    return joinpath(plugin_dir, plugin_file)
end

"""
    download_tt_pjrt_plugin_if_needed(dir=nothing)

Make sure the TT PJRT plugin library (`tt_pjrt_plugin_name[]`) exists inside
`dir`, downloading and unpacking the plugin wheel when it does not.

# Arguments
- `dir`: destination directory. When `nothing`, the directory returned by
  `get_tt_pjrt_plugin_dir()` is used.

# Returns
`nothing`.

# Throws
- An error if no destination directory is set, if the architecture is not
  `x86_64`, if the unpacked wheel does not contain exactly one `.data`
  directory, or if the plugin library is still missing after installation.

Note that installation moves the wheel's `pjrt_plugin_tt` directory onto the
destination with `force=true`, so any previous content of the destination is
replaced.
"""
function download_tt_pjrt_plugin_if_needed(dir=nothing)
    # Use a fresh binding rather than reassigning `dir`, so that the closure
    # below captures a variable that is assigned exactly once.
    dest_dir = dir === nothing ? get_tt_pjrt_plugin_dir() : dir
    # Explicit check instead of `@assert`: assertions may be disabled, and this
    # validates runtime state rather than an internal invariant.
    dest_dir === nothing && error("tt_pjrt_plugin_dir is not set!")

    tt_pjrt_plugin_path = joinpath(dest_dir, tt_pjrt_plugin_name[])
    if isfile(tt_pjrt_plugin_path)
        @debug "TT PJRT plugin already found in '$(tt_pjrt_plugin_path)', nothing to do"
        return nothing
    end

    @debug "Will install the TT PJRT plugin to '$(tt_pjrt_plugin_path)'"
    mktempdir() do tmp_dir
        zip_file_path = joinpath(tmp_dir, "pjrt-plugin-tt.zip")
        wheel_url = if Sys.ARCH === :x86_64
            "https://pypi.eng.aws.tenstorrent.com/pjrt-plugin-tt/pjrt_plugin_tt-0.6.0.dev20251113-cp311-cp311-linux_x86_64.whl"
        else
            error("Unsupported architecture: $(Sys.ARCH)")
        end
        @debug "Downloading TT PJRT plugin from '$(wheel_url)'"
        Downloads.download(wheel_url, zip_file_path)
        # The wheel is extracted as a zip archive; extraction output is discarded.
        run(pipeline(`$(p7zip()) x -tzip -o$(tmp_dir) -- $(zip_file_path)`, devnull))
        data_dirs = filter(endswith(".data"), readdir(tmp_dir; join=true))
        if length(data_dirs) != 1
            error(
                "Expected exactly one `.data` directory in the TT PJRT plugin wheel, " *
                "found $(length(data_dirs))",
            )
        end
        # We need to move the entire `pjrt_plugin_tt` directory to the destination.
        mv(joinpath(only(data_dirs), "purelib", "pjrt_plugin_tt"), dest_dir; force=true)
    end

    # Verify the installation actually produced the plugin library.
    if !isfile(tt_pjrt_plugin_path)
        error("TT PJRT plugin not found at '$(tt_pjrt_plugin_path)' after installation")
    end
    return nothing
end

end # module TT
13 changes: 13 additions & 0 deletions src/xla/IFRT/Client.jl
Original file line number Diff line number Diff line change
Expand Up @@ -115,12 +115,14 @@ const cpu_client_count = Ref(0)
const cuda_client_count = Ref(0)
const tpu_client_count = Ref(0)
const metal_client_count = Ref(0)
const tt_client_count = Ref(0)

for (backend, counter) in (
(:CPUClient, :cpu_client_count),
(:CUDAClient, :cuda_client_count),
(:TPUClient, :tpu_client_count),
(:MetalClient, :metal_client_count),
(:TTClient, :tt_client_count),
)
main_fn = Symbol(:MakeIFRTPJRT, backend)
@eval function $(backend)(args...; checkcount::Bool=true, kwargs...)
Expand Down Expand Up @@ -219,6 +221,17 @@ function MakeIFRTPJRTMetalClient(;
)
end

# Create an IFRT-PJRT client for the TT backend by loading the plugin library at
# `tt_pjrt_plugin_path` through the generic plugin-API constructor. The
# distributed-setup keywords are forwarded unchanged.
function MakeIFRTPJRTTTClient(;
    tt_pjrt_plugin_path::String,
    node_id::Integer=0,
    num_nodes::Integer=1,
    distributed_runtime_client::Union{Nothing,XLA.DistributedRuntimeClient}=nothing,
)
    device_type = "tt"
    client_name = "TT"
    return MakeIFRTPJRTClientViaPluginAPI(
        tt_pjrt_plugin_path,
        device_type,
        client_name;
        node_id=node_id,
        num_nodes=num_nodes,
        distributed_runtime_client=distributed_runtime_client,
    )
end

function MakeIFRTPJRTClientViaPluginAPI(
library_path::String,
device_type::String,
Expand Down
16 changes: 16 additions & 0 deletions src/xla/PJRT/Client.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,12 +110,14 @@ const cpu_client_count = Ref(0)
const cuda_client_count = Ref(0)
const tpu_client_count = Ref(0)
const metal_client_count = Ref(0)
const tt_client_count = Ref(0)

for (backend, counter) in (
(:CPUClient, :cpu_client_count),
(:CUDAClient, :cuda_client_count),
(:TPUClient, :tpu_client_count),
(:MetalClient, :metal_client_count),
(:TTClient, :tt_client_count),
)
main_fn = Symbol(:Make, backend)
@eval function $(backend)(args...; checkcount::Bool=true, kwargs...)
Expand Down Expand Up @@ -207,6 +209,20 @@ function MakeMetalClient(;
return MakeClientUsingPluginAPI(metal_pjrt_plugin_path, "metal", "METAL")
end

# Create a PJRT client for the TT backend from the plugin library at
# `tt_pjrt_plugin_path`. The distributed keywords are accepted only for signature
# parity with other backends: anything other than their defaults is rejected.
function MakeTTClient(;
    tt_pjrt_plugin_path::String,
    node_id::Integer=0,
    num_nodes::Integer=1,
    distributed_runtime_client::Union{Nothing,XLA.DistributedRuntimeClient}=nothing,
)
    # Checks run in this order, so the first violated one determines the error.
    @assert node_id == 0 "`PJRT.MakeTTClient` does not support node_id"
    @assert num_nodes == 1 "`PJRT.MakeTTClient` does not support num_nodes > 1"
    @assert isnothing(distributed_runtime_client) "`PJRT.MakeTTClient` does not support distributed_runtime_client"

    device_type = "tt"
    client_name = "TT"
    return MakeClientUsingPluginAPI(tt_pjrt_plugin_path, device_type, client_name)
end

function MakeClientUsingPluginAPI(
library_path::String, device_type::String, client_name::String=uppercase(device_type)
)
Expand Down
35 changes: 35 additions & 0 deletions src/xla/XLA.jl
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,41 @@ for runtime in (:PJRT, :IFRT)
catch e
println(stdout, e)
end
elseif Accelerators.TT.has_tt()
@debug "TT accelerator detected, setting it up"
try
if was_initialized && haskey(state.clients, "tt")
free_client(state.clients["tt"])
$(runtime).tt_client_count[] -= 1
end
# The env var `TT_METAL_RUNTIME_ROOT` must be set before creating the client.
tt_metal_runtime_root = get(ENV, "TT_METAL_RUNTIME_ROOT", nothing)
if isnothing(tt_metal_runtime_root)
tt_metal_path_in_wheel = joinpath(
dirname(Accelerators.TT.get_tt_pjrt_plugin_path()),
"tt-metal",
)
if ispath(tt_metal_path_in_wheel)
@debug "Setting environment variable 'TT_METAL_RUNTIME_ROOT' to '$(tt_metal_path_in_wheel)'"
ENV["TT_METAL_RUNTIME_ROOT"] = tt_metal_path_in_wheel
else
error(
"`TT_METAL_RUNTIME_ROOT` environment variable not set and we could not automatically determine it",
)
end
else
@debug "Environment variable 'TT_METAL_RUNTIME_ROOT' already set to to '$(tt_metal_runtime_root)'"
end

tt = $(runtime).TTClient(;
tt_pjrt_plugin_path=Accelerators.TT.get_tt_pjrt_plugin_path(),
common_kwargs...,
)
state.clients["tt"] = tt
state.default_client = tt
catch e
println(stdout, e)
end
elseif Reactant_jll.host_platform.tags["gpu"] != "none"
try
if was_initialized && haskey(state.clients, "cuda")
Expand Down
Loading