@@ -2,61 +2,64 @@ module TT
22
33using Reactant: Reactant
44using Scratch: @get_scratch!
5- using Downloads
5+ using Downloads: Downloads
6+ using p7zip_jll: p7zip
67
78const tt_pjrt_plugin_dir = Ref {Union{Nothing,String}} (nothing )
9+ const tt_pjrt_plugin_name = Ref {String} (" pjrt_plugin_tt.so" )
810
911function __init__ ()
1012 @static if Sys. islinux ()
11- Reactant. precompiling () || setup_tt_pjrt_plugin! ()
13+ if ! Reactant. precompiling () && has_tt ()
14+ setup_tt_pjrt_plugin! ()
15+ end
1216 end
1317end
1418
1519has_tt () = true
1620
1721function setup_tt_pjrt_plugin! ()
18- path_from_env = get (ENV , " TT_LIBRARY_PATH " , nothing )
19- if path_from_env != = nothing && ispath (path_from_env )
20- tt_pjrt_plugin_dir[] = path_from_env
22+ plugin_dir_from_env = get (ENV , " TT_PJRT_PLUGIN_DIR " , nothing )
23+ if plugin_dir_from_env != = nothing && ispath (plugin_dir_from_env )
24+ tt_pjrt_plugin_dir[] = plugin_dir_from_env
2125 else
22- tt_pjrt_plugin_dir[] = @get_scratch! (" pjrt_tt_plugin " )
26+ tt_pjrt_plugin_dir[] = @get_scratch! (" pjrt_plugin_tt " )
2327 end
24- # download_tt_pjrt_plugin_if_needed(tt_pjrt_plugin_dir[])
28+ download_tt_pjrt_plugin_if_needed (tt_pjrt_plugin_dir[])
2529 return nothing
2630end
2731
2832get_tt_pjrt_plugin_dir () = tt_pjrt_plugin_dir[]
2933
3034function get_tt_pjrt_plugin_path ()
31- return joinpath (get_tt_pjrt_plugin_dir (), " pjrt_plugin_tt.so " )
35+ return joinpath (get_tt_pjrt_plugin_dir (), tt_pjrt_plugin_name[] )
3236end
3337
34- # function download_tt_pjrt_plugin_if_needed(path=nothing)
35- # path === nothing && (path = get_tt_pjrt_plugin_dir())
36- # @assert path !== nothing "tt_pjrt_plugin_dir is not set!"
37-
38- # tt_pjrt_plugin_path = joinpath(path, "pjrt_plugin_tt_14.dylib")
39- # if !isfile(tt_pjrt_plugin_path)
40- # zip_file_path = joinpath(path, "pjrt-plugin-tt.zip")
41- # tmp_dir = joinpath(path, "tmp")
42- # Downloads.download(
43- # if Sys.ARCH === :aarch64
44- # "https://files.pythonhosted.org/packages/09/dc/6d8fbfc29d902251cf333414cf7dcfaf4b252a9920c881354584ed36270d/jax_tt-0.1.1-py3-none-macosx_13_0_arm64.whl"
45- # elseif Sys.ARCH === :x86_64
46- # "https://files.pythonhosted.org/packages/87/ec/9bb7f7f0ffd06c3fb89813126b2f698636ac7a4263ed7bdd1ff7d7c94f8f/jax_tt-0.1.1-py3-none-macosx_10_14_x86_64.whl"
47- # else
48- # error("Unsupported architecture: $(Sys.ARCH)")
49- # end,
50- # zip_file_path,
51- # )
52- # run(`unzip -qq $(zip_file_path) -d $(tmp_dir)`)
53- # mv(
54- # joinpath(tmp_dir, "jax_plugins", "tt_plugin", "pjrt_plugin_tt_14.dylib"),
55- # tt_pjrt_plugin_path,
56- # )
57- # rm(tmp_dir; recursive=true)
58- # rm(zip_file_path; recursive=true)
59- # end
60- # end
38+ function download_tt_pjrt_plugin_if_needed (dir= nothing )
39+ dir === nothing && (dir = get_tt_pjrt_plugin_dir ())
40+ @assert dir != = nothing " tt_pjrt_plugin_dir is not set!"
41+
42+ tt_pjrt_plugin_path = joinpath (dir, tt_pjrt_plugin_name[])
43+ if isfile (tt_pjrt_plugin_path)
44+ @debug " TT PJRT plugin already found in '$(tt_pjrt_plugin_path) ', nothing to do"
45+ else
46+ @debug " Will install the TT PJRT plugin to '$(tt_pjrt_plugin_path) '"
47+ mktempdir () do tmp_dir
48+ zip_file_path = joinpath (tmp_dir, " pjrt-plugin-tt.zip" )
49+ wheel_url = if Sys. ARCH === :x86_64
50+ " https://pypi.eng.aws.tenstorrent.com/pjrt-plugin-tt/pjrt_plugin_tt-0.6.0.dev20251113-cp311-cp311-linux_x86_64.whl"
51+ else
52+ error (" Unsupported architecture: $(Sys. ARCH) " )
53+ end
54+ @debug " Downloading TT PJRT plugin from '$(wheel_url) '"
55+ Downloads. download (wheel_url, zip_file_path)
56+ run (pipeline (` $(p7zip ()) x -tzip -o$(tmp_dir) -- $(zip_file_path) ` , devnull ))
57+ data_dir = only (filter! (endswith (" .data" ), readdir (tmp_dir; join= true )))
58+ # We need to move the entire `pjrt_plugin_tt` directory to the destination.
59+ mv (joinpath (data_dir, " purelib" , " pjrt_plugin_tt" ), dir; force= true )
60+ end
61+ @assert isfile (tt_pjrt_plugin_path)
62+ end
63+ end
6164
6265end # module TT
0 commit comments