@@ -343,28 +343,71 @@ function rrule(::typeof(\), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:R
343343
344344 Y = A \ B
345345
346- Atf = factorize (A' )
347-
348346 function backslash_pullback (ȳ)
349347 Ȳ = unthunk (ȳ)
348+
349+ Ȳf = Ȳ
350350 @static if VERSION >= v " 1.9"
351351 # Need to ensure Ȳ is an array since since https://github.com/JuliaLang/julia/pull/44358
352- Ȳ isa AbstractArray || Ȳ = [Ȳ]
352+ if ! isa (Ȳ, AbstractArray)
353+ Ȳf = [Ȳ]
354+ end
355+ end
356+ Yf = Y
357+ @static if VERSION >= v " 1.9"
358+ # Need to ensure Yf is an array since since https://github.com/JuliaLang/julia/pull/44358
359+ if ! isa (Y, AbstractArray)
360+ Yf = [Y]
361+ end
353362 end
354- Atf = factorize (A ' )
363+ # @info "vars" typeof(Ȳ) typeof(Y) typeof(Yf) typeof(A) typeof(B )
355364 ∂A = @thunk begin
356- B̄ = Atf \ Ȳ
365+ B̄ = A ' \ Ȳf
357366 Ā = - B̄ * Y'
358- Ā = add!! (Ā, ((B - A * Y) * B̄' ) / Atf)
359- Ā = add!! (Ā, Atf \ Y * (Ȳ' - B̄' A))
367+ t = (B - A * Y) * B̄'
368+ @static if VERSION >= v " 1.9"
369+ # Need to ensure t is an array since since https://github.com/JuliaLang/julia/pull/44358
370+ if ! isa (t, AbstractArray)
371+ t = [t]
372+ end
373+ end
374+ Ā = add!! (Ā, t / A' )
375+ Ā = add!! (Ā, A' \ Yf * (Ȳ' - B̄' A))
360376 project_A (Ā)
361377 end
362- ∂B = @thunk project_B (Atf \ Ȳ )
378+ ∂B = @thunk project_B (A ' \ Ȳf )
363379 return NoTangent (), ∂A, ∂B
364380 end
365381 return Y, backslash_pullback
366382end
367383
384+ @static if VERSION >= v " 1.9"
385+ # Need to ensure things are not scalar since since https://github.com/JuliaLang/julia/pull/44358
386+ _maybe_descalar (x) = x isa AbstractArray ? x : [x]
387+ else
388+ _maybe_descalar (x) = x
389+ end
390+
391+ function rrule (A:: AbstractVecOrMat{<:Real} , B:: AbstractVecOrMat{<:Real} )
392+ Y = A \ B
393+
394+
395+ function backslash_pullback (ȳ)
396+ Ȳ = unthunk (ȳ)
397+
398+ ∂A = @thunk begin
399+ B̄ = A' \ _maybe_descalar (Ȳ)
400+ Ā = - B̄ * Y'
401+ Ā += _maybe_descalar ((B - A * Y) * B̄' ) / A'
402+ Ā += (A' \ _maybe_descalar (Y)) * (Ȳ' - B̄' A)
403+ (Ā)
404+ end
405+ ∂B = @thunk (A' \ _maybe_descalar (Ȳ))
406+ return ∂A, ∂B
407+ end
408+ return Y, backslash_pullback
409+ end
410+
368411# ####
369412# #### `\`, `/` matrix-scalar_rule
370413# ####
0 commit comments