Skip to content

Commit f5e61ff

Browse files
authored
feat: more 1.12 support (#1796)
* fix: add _accumulate_promote_op * test: disable Zygote on 1.12 * fix: tril/triu working again * fix: inplace versions
1 parent 5462657 commit f5e61ff

File tree

3 files changed

+54
-7
lines changed

3 files changed

+54
-7
lines changed

src/TracedRArray.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1109,6 +1109,19 @@ function Base.accumulate_pairwise!(op, A::AnyTracedRVector, B::AnyTracedRVector)
11091109
return accumulate!(op, A, B; dims=1)
11101110
end
11111111

1112+
if isdefined(Base, :_accumulate_promote_op)
1113+
function Base._accumulate_promote_op(op, A::AnyTracedRArray{T}; init=nothing) where {T}
1114+
if init !== nothing
1115+
init isa TracedRNumber && (init = zero(unwrapped_eltype(init)))
1116+
end
1117+
return TracedRNumber{
1118+
unwrapped_eltype(
1119+
Base._accumulate_promote_op(op, Array{T,ndims(A)}(undef, size(A)); init)
1120+
),
1121+
}
1122+
end
1123+
end
1124+
11121125
function Base._accumulate!(
11131126
op, output::AnyTracedRArray, input::AnyTracedRVector, ::Nothing, ::Nothing
11141127
)

src/stdlibs/LinearAlgebra.jl

Lines changed: 36 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -273,26 +273,56 @@ function overloaded_mul!(
273273
return C
274274
end
275275

276-
function LinearAlgebra.triu!(@nospecialize(X::TracedRArray{T,2}), k::Integer) where {T}
276+
if isdefined(LinearAlgebra, :_triu)
277+
function LinearAlgebra._triu(A::AnyTracedRArray{T,2}, ::Val{true}, k::Integer) where {T}
278+
return overloaded_triu(materialize_traced_array(A), k)
279+
end
280+
function LinearAlgebra._triu(
281+
A::AnyTracedRArray{T,2}, ::Val{false}, k::Integer
282+
) where {T}
283+
return overloaded_triu(materialize_traced_array(A), k)
284+
end
285+
end
286+
287+
if isdefined(LinearAlgebra, :_tril)
288+
function LinearAlgebra._tril(A::AnyTracedRArray{T,2}, ::Val{true}, k::Integer) where {T}
289+
return overloaded_tril(materialize_traced_array(A), k)
290+
end
291+
function LinearAlgebra._tril(
292+
A::AnyTracedRArray{T,2}, ::Val{false}, k::Integer
293+
) where {T}
294+
return overloaded_tril(materialize_traced_array(A), k)
295+
end
296+
end
297+
298+
function LinearAlgebra.triu!(X::AnyTracedRArray{T,2}, k::Integer) where {T}
299+
set_mlir_data!(X, get_mlir_data(overloaded_triu(materialize_traced_array(X), k)))
300+
return X
301+
end
302+
303+
function LinearAlgebra.tril!(X::AnyTracedRArray{T,2}, k::Integer) where {T}
304+
set_mlir_data!(X, get_mlir_data(overloaded_tril(materialize_traced_array(X), k)))
305+
return X
306+
end
307+
308+
function overloaded_triu(X::TracedRArray{T,2}, k::Integer) where {T}
277309
iota_1 = @opcall iota(Int64, [size(X)...]; iota_dimension=1)
278310
iota_2 = @opcall subtract(
279311
@opcall(iota(Int64, [size(X)...]; iota_dimension=2)),
280312
Reactant.broadcast_to_size(k, size(X)),
281313
)
282314
idxs = @opcall compare(iota_1, iota_2; comparison_direction="LE")
283-
X.mlir_data = @opcall(select(idxs, X, zero(X))).mlir_data
284-
return X
315+
return @opcall select(idxs, X, zero(X))
285316
end
286317

287-
function LinearAlgebra.tril!(@nospecialize(X::TracedRArray{T,2}), k::Integer) where {T}
318+
function overloaded_tril(X::TracedRArray{T,2}, k::Integer) where {T}
288319
iota_1 = @opcall iota(Int64, [size(X)...]; iota_dimension=1)
289320
iota_2 = @opcall subtract(
290321
@opcall(iota(Int64, [size(X)...]; iota_dimension=2)),
291322
Reactant.broadcast_to_size(k, size(X)),
292323
)
293324
idxs = @opcall compare(iota_1, iota_2; comparison_direction="GE")
294-
X.mlir_data = @opcall(select(idxs, X, zero(X))).mlir_data
295-
return X
325+
return @opcall select(idxs, X, zero(X))
296326
end
297327

298328
# LinearAlgebra defines norm with some conditionals which cannot be traced directly

test/runtests.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,12 +53,16 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all"))
5353
@safetestset "Python" include("integration/python.jl")
5454
@safetestset "Optimisers" include("integration/optimisers.jl")
5555
@safetestset "FillArrays" include("integration/fillarrays.jl")
56-
@safetestset "Zygote" include("integration/zygote.jl")
5756
@safetestset "MPI" begin
5857
using MPI
5958
nranks = 2
6059
run(`$(mpiexec()) -n $nranks $(Base.julia_cmd()) integration/mpi.jl`)
6160
end
61+
62+
# Zygote is not supported on 1.12 https://github.com/FluxML/Zygote.jl/issues/1580
63+
if VERSION < v"1.12-"
64+
@safetestset "Zygote" include("integration/zygote.jl")
65+
end
6266
end
6367

6468
if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "neural_networks"

0 commit comments

Comments
 (0)