@@ -3,49 +3,71 @@ module LinearOperatorsChainRulesCoreExt
33using LinearOperators
44isdefined (Base, :get_extension ) ? (import ChainRulesCore) : (import .. ChainRulesCore)
55
6- function ChainRulesCore. frule ((_, Δx, _), :: typeof (* ), op:: AbstractLinearOperator{T} , x:: AbstractVector{S} ) where {T, S}
7- y = op* x
8- Δy = op* Δx
6+ function ChainRulesCore. frule (
7+ (_, Δx, _),
8+ :: typeof (* ),
9+ op:: AbstractLinearOperator{T} ,
10+ x:: AbstractVector{S} ,
11+ ) where {T, S}
12+ y = op * x
13+ Δy = op * Δx
914 return y, Δy
1015end
11- function ChainRulesCore. rrule (:: typeof (* ), op:: AbstractLinearOperator{T} , x:: AbstractVector{S} ) where {T, S}
12- y = op* x
16+ function ChainRulesCore. rrule (
17+ :: typeof (* ),
18+ op:: AbstractLinearOperator{T} ,
19+ x:: AbstractVector{S} ,
20+ ) where {T, S}
21+ y = op * x
1322 project_x = ChainRulesCore. ProjectTo (x)
1423 function mul_pullback (ȳ)
15- x̄ = project_x ( adjoint (op)* ChainRulesCore. unthunk (ȳ) )
16- return ChainRulesCore. NoTangent (), ChainRulesCore. NoTangent (), x̄
24+ x̄ = project_x (adjoint (op) * ChainRulesCore. unthunk (ȳ))
25+ return ChainRulesCore. NoTangent (), ChainRulesCore. NoTangent (), x̄
1726 end
1827 return y, mul_pullback
1928end
2029
21- function ChainRulesCore. frule ((_, Δx, _), :: typeof (* ), x:: Union{LinearOperators.Adjoint{S, V}, LinearOperators.Transpose{S, V} } , op:: AbstractLinearOperator{T} ) where {T, S, V <: AbstractVector{S} }
22- y = x* op
23- Δy = Δx* op
30+ function ChainRulesCore. frule (
31+ (_, Δx, _),
32+ :: typeof (* ),
33+ x:: Union{LinearOperators.Adjoint{S, V}, LinearOperators.Transpose{S, V}} ,
34+ op:: AbstractLinearOperator{T} ,
35+ ) where {T, S, V <: AbstractVector{S} }
36+ y = x * op
37+ Δy = Δx * op
2438 return y, Δy
2539end
26- function ChainRulesCore. rrule (:: typeof (* ), x:: LinearOperators.Transpose{S, V} , op:: AbstractLinearOperator{T} ) where {T, S, V <: AbstractVector{S} }
27- y = x* op
40+ function ChainRulesCore. rrule (
41+ :: typeof (* ),
42+ x:: LinearOperators.Transpose{S, V} ,
43+ op:: AbstractLinearOperator{T} ,
44+ ) where {T, S, V <: AbstractVector{S} }
45+ y = x * op
2846 project_x = ChainRulesCore. ProjectTo (x)
2947 function mul_pullback (ȳ)
30- # needed to make sure that ȳ is recognized as Transposed
31- # ȳ_ = transpose(collect(vec(ChainRulesCore.unthunk(ȳ))))
32- ȳ_ = transpose (vec (ChainRulesCore. unthunk (ȳ)))
33- x̄ = project_x (ȳ_* adjoint (op))
34- return ChainRulesCore. NoTangent (), x̄, ChainRulesCore. NoTangent ()
48+ # needed to make sure that ȳ is recognized as Transposed
49+ # ȳ_ = transpose(collect(vec(ChainRulesCore.unthunk(ȳ))))
50+ ȳ_ = transpose (vec (ChainRulesCore. unthunk (ȳ)))
51+ x̄ = project_x (ȳ_ * adjoint (op))
52+ return ChainRulesCore. NoTangent (), x̄, ChainRulesCore. NoTangent ()
3553 end
3654 return y, mul_pullback
3755end
38- function ChainRulesCore. rrule (:: typeof (* ), x:: LinearOperators.Adjoint{S, V} , op:: AbstractLinearOperator{T} ) where {T, S, V <: AbstractVector{S} }
39- y = x* op
56+ function ChainRulesCore. rrule (
57+ :: typeof (* ),
58+ x:: LinearOperators.Adjoint{S, V} ,
59+ op:: AbstractLinearOperator{T} ,
60+ ) where {T, S, V <: AbstractVector{S} }
61+ y = x * op
4062 project_x = ChainRulesCore. ProjectTo (x)
4163 function mul_pullback (ȳ)
42- # needed to make sure that ȳ is recognized as Adjoint
43- # ȳ_ = adjoint(collect(vec(ChainRulesCore.unthunk(ȳ))))
44- ȳ_ = adjoint (conj .(vec (ChainRulesCore. unthunk (ȳ))))
45- x̄ = project_x (ȳ_* adjoint (op))
46- return ChainRulesCore. NoTangent (), x̄, ChainRulesCore. NoTangent ()
64+ # needed to make sure that ȳ is recognized as Adjoint
65+ # ȳ_ = adjoint(collect(vec(ChainRulesCore.unthunk(ȳ))))
66+ ȳ_ = adjoint (conj .(vec (ChainRulesCore. unthunk (ȳ))))
67+ x̄ = project_x (ȳ_ * adjoint (op))
68+ return ChainRulesCore. NoTangent (), x̄, ChainRulesCore. NoTangent ()
4769 end
4870 return y, mul_pullback
4971end
5072
51- end # module
73+ end # module
0 commit comments