Skip to content

Commit 91e0d7c

Browse files
committed
Create TT client via plugin for TensTorrent devices
1 parent 81926ba commit 91e0d7c

File tree

5 files changed

+123
-0
lines changed

5 files changed

+123
-0
lines changed

src/accelerators/Accelerators.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,6 @@ module Accelerators
22

33
include("TPU.jl")
44
include("Metal.jl")
5+
include("TT.jl")
56

67
end

src/accelerators/TT.jl

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
module TT
2+
3+
using Reactant: Reactant
4+
using Scratch: @get_scratch!
5+
using Downloads
6+
7+
const tt_pjrt_plugin_dir = Ref{Union{Nothing,String}}(nothing)
8+
9+
function __init__()
10+
@static if Sys.islinux()
11+
Reactant.precompiling() || setup_tt_pjrt_plugin!()
12+
end
13+
end
14+
15+
has_tt() = true
16+
17+
function setup_tt_pjrt_plugin!()
18+
path_from_env = get(ENV, "TT_LIBRARY_PATH", nothing)
19+
if path_from_env !== nothing && ispath(path_from_env)
20+
tt_pjrt_plugin_dir[] = path_from_env
21+
else
22+
tt_pjrt_plugin_dir[] = @get_scratch!("pjrt_tt_plugin")
23+
end
24+
# download_tt_pjrt_plugin_if_needed(tt_pjrt_plugin_dir[])
25+
return nothing
26+
end
27+
28+
get_tt_pjrt_plugin_dir() = tt_pjrt_plugin_dir[]
29+
30+
function get_tt_pjrt_plugin_path()
31+
return joinpath(get_tt_pjrt_plugin_dir(), "pjrt_plugin_tt.so")
32+
end
33+
34+
# function download_tt_pjrt_plugin_if_needed(path=nothing)
35+
# path === nothing && (path = get_tt_pjrt_plugin_dir())
36+
# @assert path !== nothing "tt_pjrt_plugin_dir is not set!"
37+
38+
# tt_pjrt_plugin_path = joinpath(path, "pjrt_plugin_tt_14.dylib")
39+
# if !isfile(tt_pjrt_plugin_path)
40+
# zip_file_path = joinpath(path, "pjrt-plugin-tt.zip")
41+
# tmp_dir = joinpath(path, "tmp")
42+
# Downloads.download(
43+
# if Sys.ARCH === :aarch64
44+
# "https://files.pythonhosted.org/packages/09/dc/6d8fbfc29d902251cf333414cf7dcfaf4b252a9920c881354584ed36270d/jax_tt-0.1.1-py3-none-macosx_13_0_arm64.whl"
45+
# elseif Sys.ARCH === :x86_64
46+
# "https://files.pythonhosted.org/packages/87/ec/9bb7f7f0ffd06c3fb89813126b2f698636ac7a4263ed7bdd1ff7d7c94f8f/jax_tt-0.1.1-py3-none-macosx_10_14_x86_64.whl"
47+
# else
48+
# error("Unsupported architecture: $(Sys.ARCH)")
49+
# end,
50+
# zip_file_path,
51+
# )
52+
# run(`unzip -qq $(zip_file_path) -d $(tmp_dir)`)
53+
# mv(
54+
# joinpath(tmp_dir, "jax_plugins", "tt_plugin", "pjrt_plugin_tt_14.dylib"),
55+
# tt_pjrt_plugin_path,
56+
# )
57+
# rm(tmp_dir; recursive=true)
58+
# rm(zip_file_path; recursive=true)
59+
# end
60+
# end
61+
62+
end # module TT

src/xla/IFRT/Client.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,12 +115,14 @@ const cpu_client_count = Ref(0)
115115
const cuda_client_count = Ref(0)
116116
const tpu_client_count = Ref(0)
117117
const metal_client_count = Ref(0)
118+
const tt_client_count = Ref(0)
118119

119120
for (backend, counter) in (
120121
(:CPUClient, :cpu_client_count),
121122
(:CUDAClient, :cuda_client_count),
122123
(:TPUClient, :tpu_client_count),
123124
(:MetalClient, :metal_client_count),
125+
(:TTClient, :tt_client_count),
124126
)
125127
main_fn = Symbol(:MakeIFRTPJRT, backend)
126128
@eval function $(backend)(args...; checkcount::Bool=true, kwargs...)
@@ -219,6 +221,22 @@ function MakeIFRTPJRTMetalClient(;
219221
)
220222
end
221223

224+
function MakeIFRTPJRTTTClient(;
225+
rocm_pjrt_plugin_path::String,
226+
node_id::Integer=0,
227+
num_nodes::Integer=1,
228+
distributed_runtime_client::Union{Nothing,XLA.DistributedRuntimeClient}=nothing,
229+
)
230+
return MakeIFRTPJRTClientViaPluginAPI(
231+
rocm_pjrt_plugin_path,
232+
"rocm",
233+
"TT";
234+
node_id,
235+
num_nodes,
236+
distributed_runtime_client,
237+
)
238+
end
239+
222240
function MakeIFRTPJRTClientViaPluginAPI(
223241
library_path::String,
224242
device_type::String,

src/xla/PJRT/Client.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,12 +110,14 @@ const cpu_client_count = Ref(0)
110110
const cuda_client_count = Ref(0)
111111
const tpu_client_count = Ref(0)
112112
const metal_client_count = Ref(0)
113+
const tt_client_count = Ref(0)
113114

114115
for (backend, counter) in (
115116
(:CPUClient, :cpu_client_count),
116117
(:CUDAClient, :cuda_client_count),
117118
(:TPUClient, :tpu_client_count),
118119
(:MetalClient, :metal_client_count),
120+
(:TTClient, :tt_client_count),
119121
)
120122
main_fn = Symbol(:Make, backend)
121123
@eval function $(backend)(args...; checkcount::Bool=true, kwargs...)
@@ -207,6 +209,20 @@ function MakeMetalClient(;
207209
return MakeClientUsingPluginAPI(metal_pjrt_plugin_path, "metal", "METAL")
208210
end
209211

212+
function MakeTTClient(;
213+
tt_pjrt_plugin_path::String,
214+
node_id::Integer=0,
215+
num_nodes::Integer=1,
216+
distributed_runtime_client::Union{Nothing,XLA.DistributedRuntimeClient}=nothing,
217+
)
218+
@assert node_id == 0 "`PJRT.MakeTTClient` does not support node_id"
219+
@assert num_nodes == 1 "`PJRT.MakeTTClient` does not support num_nodes > 1"
220+
@assert distributed_runtime_client === nothing "`PJRT.MakeTTClient` does not support \
221+
distributed_runtime_client"
222+
223+
return MakeClientUsingPluginAPI(tt_pjrt_plugin_path, "tt", "TT")
224+
end
225+
210226
function MakeClientUsingPluginAPI(
211227
library_path::String, device_type::String, client_name::String=uppercase(device_type)
212228
)

src/xla/XLA.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,32 @@ for runtime in (:PJRT, :IFRT)
226226
catch e
227227
println(stdout, e)
228228
end
229+
elseif Accelerators.TT.has_tt()
230+
try
231+
if was_initialized && haskey(state.clients, "tt")
232+
XLA.free_client(state.clients["tt"])
233+
XLA.$(runtime).tt_client_count[] -= 1
234+
end
235+
# The env var `TT_METAL_RUNTIME_ROOT` must be set before creating the client.
236+
if isnothing(get(ENV, "TT_METAL_RUNTIME_ROOT", nothing))
237+
tt_metal_path_in_wheel = joinpath(dirname(Accelerators.TT.get_tt_pjrt_plugin_path()), "tt-metal")
238+
if ispath(tt_metal_path_in_wheel)
239+
ENV["TT_METAL_RUNTIME_ROOT"] = tt_metal_path_in_wheel
240+
else
241+
error("`TT_METAL_RUNTIME_ROOT` environment variable not set and we could not automatically determine it")
242+
end
243+
end
244+
245+
tt = $(runtime).TTClient(
246+
;
247+
tt_pjrt_plugin_path=Accelerators.TT.get_tt_pjrt_plugin_path(),
248+
common_kwargs...
249+
)
250+
state.clients["tt"] = tt
251+
state.default_client = tt
252+
catch e
253+
println(stdout, e)
254+
end
229255
elseif Reactant_jll.host_platform.tags["gpu"] != "none"
230256
try
231257
if was_initialized && haskey(state.clients, "cuda")

0 commit comments

Comments
 (0)