@@ -4,27 +4,95 @@ ChainRulesCore.@scalar_rule +(x::APL) true
44ChainRulesCore. @scalar_rule - (x:: APL ) - 1
55
66ChainRulesCore. @scalar_rule + (x:: APL , y:: APL ) (true , true )
7+ function plusconstant1_pullback (Δ)
8+ return ChainRulesCore. NoTangent (), Δ, coefficient (Δ, constantmonomial (Δ))
9+ end
10+ function ChainRulesCore. rrule (:: typeof (plusconstant), p:: APL , α)
11+ return plusconstant (p, α), plusconstant1_pullback
12+ end
13+ function plusconstant2_pullback (Δ)
14+ return ChainRulesCore. NoTangent (), coefficient (Δ, constantmonomial (Δ)), Δ
15+ end
16+ function ChainRulesCore. rrule (:: typeof (plusconstant), α, p:: APL )
17+ return plusconstant (α, p), plusconstant2_pullback
18+ end
719ChainRulesCore. @scalar_rule - (x:: APL , y:: APL ) (true , - 1 )
820
921function ChainRulesCore. frule ((_, Δp, Δq), :: typeof (* ), p:: APL , q:: APL )
1022 return p * q, MA. add_mul!! (p * Δq, q, Δp)
1123end
24+
25+ function _adjoint_mult (op:: F , ts, p, Δ) where {F<: Function }
26+ for t in terms (p)
27+ c = coefficient (t)
28+ m = monomial (t)
29+ for δ in Δ
30+ if divides (m, δ)
31+ coef = op (c, coefficient (δ))
32+ mono = _div (monomial (δ), m)
33+ push! (ts, term (coef, mono))
34+ end
35+ end
36+ end
37+ return polynomial (ts)
38+ end
39+ function adjoint_mult_left (p, Δ)
40+ ts = MA. promote_operation (* , MA. promote_operation (adjoint, termtype (p)), termtype (Δ))[]
41+ return _adjoint_mult (ts, p, Δ) do c, d
42+ c' * d
43+ end
44+ end
45+ function adjoint_mult_right (p, Δ)
46+ ts = MA. promote_operation (* , termtype (Δ), MA. promote_operation (adjoint, termtype (p)))[]
47+ return _adjoint_mult (ts, p, Δ) do c, d
48+ d * c'
49+ end
50+ end
51+
1252function ChainRulesCore. rrule (:: typeof (* ), p:: APL , q:: APL )
1353 function times_pullback2 (ΔΩ̇)
14- # ΔΩ = ChainRulesCore.unthunk(Ω̇)
15- # return (ChainRulesCore.NoTangent(), ChainRulesCore.ProjectTo(p)(ΔΩ * q'), ChainRulesCore.ProjectTo(q)(p' * ΔΩ))
16- return (ChainRulesCore. NoTangent (), ΔΩ̇ * q' , p' * ΔΩ̇)
54+ return (ChainRulesCore. NoTangent (), adjoint_mult_right (q, ΔΩ̇), adjoint_mult_left (p, ΔΩ̇))
1755 end
1856 return p * q, times_pullback2
1957end
2058
59+ function ChainRulesCore. rrule (:: typeof (multconstant), α, p:: APL )
60+ function times_pullback2 (ΔΩ̇)
61+ # TODO we could make it faster, don't need to compute `Δα` entirely if we only care about the constant term.
62+ Δα = adjoint_mult_right (p, ΔΩ̇)
63+ return (ChainRulesCore. NoTangent (), coefficient (Δα, constantmonomial (Δα)), α' * ΔΩ̇)
64+ end
65+ return multconstant (α, p), times_pullback2
66+ end
67+
68+ function ChainRulesCore. rrule (:: typeof (multconstant), p:: APL , α)
69+ function times_pullback2 (ΔΩ̇)
70+ # TODO we could make it faster, don't need to compute `Δα` entirely if we only care about the constant term.
71+ Δα = adjoint_mult_left (p, ΔΩ̇)
72+ return (ChainRulesCore. NoTangent (), ΔΩ̇ * α' , coefficient (Δα, constantmonomial (Δα)))
73+ end
74+ return multconstant (p, α), times_pullback2
75+ end
76+
77+ notangent3 (Δ) = ChainRulesCore. NoTangent (), ChainRulesCore. NoTangent (), ChainRulesCore. NoTangent ()
78+ function ChainRulesCore. rrule (:: typeof (^ ), mono:: AbstractMonomialLike , i:: Integer )
79+ return mono^ i, notangent3
80+ end
81+
2182function ChainRulesCore. frule ((_, Δp, _), :: typeof (differentiate), p, x)
2283 return differentiate (p, x), differentiate (Δp, x)
2384end
24- function pullback (Δdpdx, x)
85+ function differentiate_pullback (Δdpdx, x)
2586 return ChainRulesCore. NoTangent (), x * differentiate (x * Δdpdx, x), ChainRulesCore. NoTangent ()
2687end
2788function ChainRulesCore. rrule (:: typeof (differentiate), p, x)
2889 dpdx = differentiate (p, x)
29- return dpdx, Base. Fix2 (pullback, x)
90+ return dpdx, Base. Fix2 (differentiate_pullback, x)
91+ end
92+
93+ function coefficient_pullback (Δ, m:: AbstractMonomialLike )
94+ return ChainRulesCore. NoTangent (), polynomial (term (Δ, m)), ChainRulesCore. NoTangent ()
95+ end
96+ function ChainRulesCore. rrule (:: typeof (coefficient), p:: APL , m:: AbstractMonomialLike )
97+ return coefficient (p, m), Base. Fix2 (coefficient_pullback, m)
3098end
0 commit comments