1+ module LinearOperatorsChainRulesCoreExt
2+
3+ using LinearOperators
4+ import ChainRulesCore
5+
6+ function ChainRulesCore. frule ((_, Δx, _), :: typeof (* ), op:: AbstractLinearOperator{T} , x:: AbstractVector{S} ) where {T, S}
7+ y = op* x
8+ Δy = op* Δx
9+ return y, Δy
10+ end
11+ function ChainRulesCore. rrule (:: typeof (* ), op:: AbstractLinearOperator{T} , x:: AbstractVector{S} ) where {T, S}
12+ y = op* x
13+ project_x = ChainRulesCore. ProjectTo (x)
14+ function mul_pullback (ȳ)
15+ x̄ = project_x ( adjoint (op)* ChainRulesCore. unthunk (ȳ) )
16+ return ChainRulesCore. NoTangent (), ChainRulesCore. NoTangent (), x̄
17+ end
18+ return y, mul_pullback
19+ end
20+
21+ function ChainRulesCore. frule ((_, Δx, _), :: typeof (* ), x:: Union{LinearOperators.Adjoint{S, V}, LinearOperators.Transpose{S, V} } , op:: AbstractLinearOperator{T} ) where {T, S, V <: AbstractVector{S} }
22+ y = x* op
23+ Δy = Δx* op
24+ return y, Δy
25+ end
26+ function ChainRulesCore. rrule (:: typeof (* ), x:: LinearOperators.Transpose{S, V} , op:: AbstractLinearOperator{T} ) where {T, S, V <: AbstractVector{S} }
27+ y = x* op
28+ project_x = ChainRulesCore. ProjectTo (x)
29+ function mul_pullback (ȳ)
30+ # needed to make sure that ȳ is recognized as Transposed
31+ # ȳ_ = transpose(collect(vec(ChainRulesCore.unthunk(ȳ))))
32+ ȳ_ = transpose (vec (ChainRulesCore. unthunk (ȳ)))
33+ x̄ = project_x (ȳ_* adjoint (op))
34+ return ChainRulesCore. NoTangent (), x̄, ChainRulesCore. NoTangent ()
35+ end
36+ return y, mul_pullback
37+ end
38+ function ChainRulesCore. rrule (:: typeof (* ), x:: LinearOperators.Adjoint{S, V} , op:: AbstractLinearOperator{T} ) where {T, S, V <: AbstractVector{S} }
39+ y = x* op
40+ project_x = ChainRulesCore. ProjectTo (x)
41+ function mul_pullback (ȳ)
42+ # needed to make sure that ȳ is recognized as Adjoint
43+ # ȳ_ = adjoint(collect(vec(ChainRulesCore.unthunk(ȳ))))
44+ ȳ_ = adjoint (conj .(vec (ChainRulesCore. unthunk (ȳ))))
45+ x̄ = project_x (ȳ_* adjoint (op))
46+ return ChainRulesCore. NoTangent (), x̄, ChainRulesCore. NoTangent ()
47+ end
48+ return y, mul_pullback
49+ end
50+
51+ end # module
0 commit comments