@@ -273,26 +273,56 @@ function overloaded_mul!(
273273 return C
274274end
275275
276- function LinearAlgebra. triu! (@nospecialize (X:: TracedRArray{T,2} ), k:: Integer ) where {T}
276+ if isdefined (LinearAlgebra, :_triu )
277+ function LinearAlgebra. _triu (A:: AnyTracedRArray{T,2} , :: Val{true} , k:: Integer ) where {T}
278+ return overloaded_triu (materialize_traced_array (A), k)
279+ end
280+ function LinearAlgebra. _triu (
281+ A:: AnyTracedRArray{T,2} , :: Val{false} , k:: Integer
282+ ) where {T}
283+ return overloaded_triu (materialize_traced_array (A), k)
284+ end
285+ end
286+
287+ if isdefined (LinearAlgebra, :_tril )
288+ function LinearAlgebra. _tril (A:: AnyTracedRArray{T,2} , :: Val{true} , k:: Integer ) where {T}
289+ return overloaded_tril (materialize_traced_array (A), k)
290+ end
291+ function LinearAlgebra. _tril (
292+ A:: AnyTracedRArray{T,2} , :: Val{false} , k:: Integer
293+ ) where {T}
294+ return overloaded_tril (materialize_traced_array (A), k)
295+ end
296+ end
297+
298+ function LinearAlgebra. triu! (X:: AnyTracedRArray{T,2} , k:: Integer ) where {T}
299+ set_mlir_data! (X, get_mlir_data (overloaded_triu (materialize_traced_array (X), k)))
300+ return X
301+ end
302+
303+ function LinearAlgebra. tril! (X:: AnyTracedRArray{T,2} , k:: Integer ) where {T}
304+ set_mlir_data! (X, get_mlir_data (overloaded_tril (materialize_traced_array (X), k)))
305+ return X
306+ end
307+
308+ function overloaded_triu (X:: TracedRArray{T,2} , k:: Integer ) where {T}
277309 iota_1 = @opcall iota (Int64, [size (X)... ]; iota_dimension= 1 )
278310 iota_2 = @opcall subtract (
279311 @opcall (iota (Int64, [size (X)... ]; iota_dimension= 2 )),
280312 Reactant. broadcast_to_size (k, size (X)),
281313 )
282314 idxs = @opcall compare (iota_1, iota_2; comparison_direction= " LE" )
283- X. mlir_data = @opcall (select (idxs, X, zero (X))). mlir_data
284- return X
315+ return @opcall select (idxs, X, zero (X))
285316end
286317
287- function LinearAlgebra . tril! ( @nospecialize ( X:: TracedRArray{T,2} ) , k:: Integer ) where {T}
318+ function overloaded_tril ( X:: TracedRArray{T,2} , k:: Integer ) where {T}
288319 iota_1 = @opcall iota (Int64, [size (X)... ]; iota_dimension= 1 )
289320 iota_2 = @opcall subtract (
290321 @opcall (iota (Int64, [size (X)... ]; iota_dimension= 2 )),
291322 Reactant. broadcast_to_size (k, size (X)),
292323 )
293324 idxs = @opcall compare (iota_1, iota_2; comparison_direction= " GE" )
294- X. mlir_data = @opcall (select (idxs, X, zero (X))). mlir_data
295- return X
325+ return @opcall select (idxs, X, zero (X))
296326end
297327
298328# LinearAlgebra defines norm with some conditionals which cannot be traced directly
0 commit comments