Skip to content

Commit f498044

Browse files
authored
feat: new jll version + new compiler passes (#1791)
* feat: new jll version + new compiler passes * fix: use newer libtpu builds * fix: try always downloading libtpu on ci * feat: version check for libtpu
1 parent 54f5d5d commit f498044

File tree

4 files changed

+13
-7
lines changed

4 files changed

+13
-7
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Reactant"
22
uuid = "3c362404-f566-11ee-1572-e11a4b42c853"
33
authors = ["William Moses <wmoses@mit.edu>", "Valentin Churavy <vchuravy@mit.edu>", "Sergio Sánchez Ramírez <sergio.sanchez.ramirez@bsc.es>", "Paul Berg <paul@plutojl.org>", "Avik Pal <avikpal@mit.edu>", "Mosè Giordano <mose@gnu.org>"]
4-
version = "0.2.171"
4+
version = "0.2.172"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -105,7 +105,7 @@ PythonCall = "0.9.25"
105105
Random = "1.10"
106106
Random123 = "1.7"
107107
ReactantCore = "0.1.16"
108-
Reactant_jll = "0.0.253"
108+
Reactant_jll = "0.0.254"
109109
ScopedValues = "1.3.0"
110110
Scratch = "1.2"
111111
Sockets = "1.10"

src/Compiler.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -910,6 +910,7 @@ function optimization_passes(
910910
"remove_no_ops_from_while_loop",
911911
"while_is_copy_simplify",
912912
"split_variadic_scatter_op",
913+
"dynamic_slice_simplify",
913914
]
914915

915916
if !compile_options.disable_auto_batching_passes
@@ -1124,9 +1125,11 @@ function optimization_passes(
11241125
if AGGRESSIVE_PROPAGATION[]
11251126
push!(transform_passes_list, "reshape_slice(0)")
11261127
push!(transform_passes_list, "reshape_elementwise(0)")
1128+
push!(transform_passes_list, "reshape_dynamic_slice(0)")
11271129
else
11281130
push!(transform_passes_list, "reshape_slice(1)")
11291131
push!(transform_passes_list, "reshape_elementwise(1)")
1132+
push!(transform_passes_list, "reshape_dynamic_slice(1)")
11301133
end
11311134
elseif compile_options.reshape_propagate === :down
11321135
append!(

src/accelerators/TPU.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@ using unzip_jll: unzip
1010
const libtpu_dir = Ref{Union{Nothing,String}}(nothing)
1111
const RUNNING_IN_CLOUD_TPU_VM = Ref(false)
1212

13+
const LIBTPU_VERSION = "0.0.28.dev20251027"
14+
const LIBTPU_SO = "libtpu-$(replace(string(LIBTPU_VERSION), '.' => '_')).so"
15+
1316
function __init__()
1417
@static if !Sys.isapple()
1518
if !Reactant.precompiling() && has_tpu()
@@ -32,18 +35,18 @@ end
3235

3336
get_libtpu_dir() = libtpu_dir[]
3437

35-
get_libtpu_path() = joinpath(get_libtpu_dir(), "libtpu.so")
38+
get_libtpu_path() = joinpath(get_libtpu_dir(), LIBTPU_SO)
3639

3740
function download_libtpu_if_needed(path=nothing)
3841
path === nothing && (path = get_libtpu_dir())
3942
@assert path !== nothing "libtpu_dir is not set!"
4043

41-
libtpu_path = joinpath(path, "libtpu.so")
44+
libtpu_path = joinpath(path, LIBTPU_SO)
4245
if !isfile(libtpu_path)
4346
zip_file_path = joinpath(path, "tpu.zip")
4447
tmp_dir = joinpath(path, "tmp")
4548
Downloads.download(
46-
"https://storage.googleapis.com/libtpu-nightly-releases/wheels/libtpu-nightly/libtpu_nightly-0.1.dev20250811+nightly-py3-none-manylinux_2_31_x86_64.whl",
49+
"https://storage.googleapis.com/libtpu-nightly-releases/wheels/libtpu/libtpu-0.0.28.dev20251027+nightly-cp314-cp314t-manylinux_2_31_x86_64.whl",
4750
zip_file_path,
4851
)
4952
run(`$(unzip()) -qq $(zip_file_path) -d $(tmp_dir)`)

test/ops.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -791,12 +791,12 @@ end
791791

792792
@testset "acos" begin
793793
x = Reactant.to_rarray(Float32[-1.0, 0.0, 1.0])
794-
@test acos.(Array(x)) @jit(Ops.acos(x)) broken = RunningOnTPU
794+
@test acos.(Array(x)) @jit(Ops.acos(x))
795795
end
796796

797797
@testset "acosh" begin
798798
x = Reactant.to_rarray(Float32[1.0, 10.0])
799-
@test acosh.(Array(x)) @jit(Ops.acosh(x)) broken = RunningOnTPU
799+
@test acosh.(Array(x)) @jit(Ops.acosh(x))
800800
end
801801

802802
@testset "asin" begin

0 commit comments

Comments
 (0)