Skip to content

Commit 72852fc

Browse files
committed
Automatically download wheel for TT PJRT plugin
1 parent 596d4fa commit 72852fc

File tree

2 files changed

+44
-36
lines changed

2 files changed

+44
-36
lines changed

src/accelerators/TT.jl

Lines changed: 38 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -2,61 +2,64 @@ module TT
22

33
using Reactant: Reactant
44
using Scratch: @get_scratch!
5-
using Downloads
5+
using Downloads: Downloads
6+
using p7zip_jll: p7zip
67

78
const tt_pjrt_plugin_dir = Ref{Union{Nothing,String}}(nothing)
9+
const tt_pjrt_plugin_name = Ref{String}("pjrt_plugin_tt.so")
810

911
function __init__()
1012
@static if Sys.islinux()
11-
Reactant.precompiling() || setup_tt_pjrt_plugin!()
13+
if !Reactant.precompiling() && has_tt()
14+
setup_tt_pjrt_plugin!()
15+
end
1216
end
1317
end
1418

1519
has_tt() = true
1620

1721
function setup_tt_pjrt_plugin!()
18-
path_from_env = get(ENV, "TT_LIBRARY_PATH", nothing)
19-
if path_from_env !== nothing && ispath(path_from_env)
20-
tt_pjrt_plugin_dir[] = path_from_env
22+
plugin_dir_from_env = get(ENV, "TT_PJRT_PLUGIN_DIR", nothing)
23+
if plugin_dir_from_env !== nothing && ispath(plugin_dir_from_env)
24+
tt_pjrt_plugin_dir[] = plugin_dir_from_env
2125
else
22-
tt_pjrt_plugin_dir[] = @get_scratch!("pjrt_tt_plugin")
26+
tt_pjrt_plugin_dir[] = @get_scratch!("pjrt_plugin_tt")
2327
end
24-
# download_tt_pjrt_plugin_if_needed(tt_pjrt_plugin_dir[])
28+
download_tt_pjrt_plugin_if_needed(tt_pjrt_plugin_dir[])
2529
return nothing
2630
end
2731

2832
get_tt_pjrt_plugin_dir() = tt_pjrt_plugin_dir[]
2933

3034
function get_tt_pjrt_plugin_path()
31-
return joinpath(get_tt_pjrt_plugin_dir(), "pjrt_plugin_tt.so")
35+
return joinpath(get_tt_pjrt_plugin_dir(), tt_pjrt_plugin_name[])
3236
end
3337

34-
# function download_tt_pjrt_plugin_if_needed(path=nothing)
35-
# path === nothing && (path = get_tt_pjrt_plugin_dir())
36-
# @assert path !== nothing "tt_pjrt_plugin_dir is not set!"
37-
38-
# tt_pjrt_plugin_path = joinpath(path, "pjrt_plugin_tt_14.dylib")
39-
# if !isfile(tt_pjrt_plugin_path)
40-
# zip_file_path = joinpath(path, "pjrt-plugin-tt.zip")
41-
# tmp_dir = joinpath(path, "tmp")
42-
# Downloads.download(
43-
# if Sys.ARCH === :aarch64
44-
# "https://files.pythonhosted.org/packages/09/dc/6d8fbfc29d902251cf333414cf7dcfaf4b252a9920c881354584ed36270d/jax_tt-0.1.1-py3-none-macosx_13_0_arm64.whl"
45-
# elseif Sys.ARCH === :x86_64
46-
# "https://files.pythonhosted.org/packages/87/ec/9bb7f7f0ffd06c3fb89813126b2f698636ac7a4263ed7bdd1ff7d7c94f8f/jax_tt-0.1.1-py3-none-macosx_10_14_x86_64.whl"
47-
# else
48-
# error("Unsupported architecture: $(Sys.ARCH)")
49-
# end,
50-
# zip_file_path,
51-
# )
52-
# run(`unzip -qq $(zip_file_path) -d $(tmp_dir)`)
53-
# mv(
54-
# joinpath(tmp_dir, "jax_plugins", "tt_plugin", "pjrt_plugin_tt_14.dylib"),
55-
# tt_pjrt_plugin_path,
56-
# )
57-
# rm(tmp_dir; recursive=true)
58-
# rm(zip_file_path; recursive=true)
59-
# end
60-
# end
38+
function download_tt_pjrt_plugin_if_needed(dir=nothing)
39+
dir === nothing && (dir = get_tt_pjrt_plugin_dir())
40+
@assert dir !== nothing "tt_pjrt_plugin_dir is not set!"
41+
42+
tt_pjrt_plugin_path = joinpath(dir, tt_pjrt_plugin_name[])
43+
if isfile(tt_pjrt_plugin_path)
44+
@debug "TT PJRT plugin already found in '$(tt_pjrt_plugin_path)', nothing to do"
45+
else
46+
@debug "Will install the TT PJRT plugin to '$(tt_pjrt_plugin_path)'"
47+
mktempdir() do tmp_dir
48+
zip_file_path = joinpath(tmp_dir, "pjrt-plugin-tt.zip")
49+
wheel_url = if Sys.ARCH === :x86_64
50+
"https://pypi.eng.aws.tenstorrent.com/pjrt-plugin-tt/pjrt_plugin_tt-0.6.0.dev20251113-cp311-cp311-linux_x86_64.whl"
51+
else
52+
error("Unsupported architecture: $(Sys.ARCH)")
53+
end
54+
@debug "Downloading TT PJRT plugin from '$(wheel_url)'"
55+
Downloads.download(wheel_url, zip_file_path)
56+
run(pipeline(`$(p7zip()) x -tzip -o$(tmp_dir) -- $(zip_file_path)`, devnull))
57+
data_dir = only(filter!(endswith(".data"), readdir(tmp_dir; join=true)))
58+
# We need to move the entire `pjrt_plugin_tt` directory to the destination.
59+
mv(joinpath(data_dir, "purelib", "pjrt_plugin_tt"), dir; force=true)
60+
end
61+
@assert isfile(tt_pjrt_plugin_path)
62+
end
63+
end
6164

6265
end # module TT

src/xla/XLA.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -227,24 +227,29 @@ for runtime in (:PJRT, :IFRT)
227227
println(stdout, e)
228228
end
229229
elseif Accelerators.TT.has_tt()
230+
@debug "TT accelerator detected, setting it up"
230231
try
231232
if was_initialized && haskey(state.clients, "tt")
232233
XLA.free_client(state.clients["tt"])
233234
XLA.$(runtime).tt_client_count[] -= 1
234235
end
235236
# The env var `TT_METAL_RUNTIME_ROOT` must be set before creating the client.
236-
if isnothing(get(ENV, "TT_METAL_RUNTIME_ROOT", nothing))
237+
tt_metal_runtime_root = get(ENV, "TT_METAL_RUNTIME_ROOT", nothing)
238+
if isnothing(tt_metal_runtime_root)
237239
tt_metal_path_in_wheel = joinpath(
238240
dirname(Accelerators.TT.get_tt_pjrt_plugin_path()),
239241
"tt-metal",
240242
)
241243
if ispath(tt_metal_path_in_wheel)
244+
@debug "Setting environment variable 'TT_METAL_RUNTIME_ROOT' to '$(tt_metal_path_in_wheel)'"
242245
ENV["TT_METAL_RUNTIME_ROOT"] = tt_metal_path_in_wheel
243246
else
244247
error(
245248
"`TT_METAL_RUNTIME_ROOT` environment variable not set and we could not automatically determine it",
246249
)
247250
end
251+
else
252+
@debug "Environment variable 'TT_METAL_RUNTIME_ROOT' already set to to '$(tt_metal_runtime_root)'"
248253
end
249254

250255
tt = $(runtime).TTClient(;

0 commit comments

Comments
 (0)