## Docstrings
22
"""
    prepare_hvp(f, backend, x, dx) -> extras

Create an `extras` object that can be given to [`hvp`](@ref) and its variants.

@@ -11,7 +11,7 @@ Create an `extras` object that can be given to [`hvp`](@ref) and its variants.
function prepare_hvp end  # method-less declaration; the methods are defined further down
1212
"""
    prepare_hvp_same_point(f, backend, x, dx) -> extras_same

Create an `extras_same` object that can be given to [`hvp`](@ref) and its variants _if they are applied at the same point `x`_.

@@ -21,16 +21,16 @@ Create an `extras_same` object that can be given to [`hvp`](@ref) and its varian
function prepare_hvp_same_point end  # method-less declaration; methods are defined further down
2222
"""
    hvp(f, backend, x, dx, [extras]) -> p

Compute the Hessian-vector product of `f` at point `x` with seed `dx`.

If `extras` is omitted, it is created on the fly by calling `prepare_hvp(f, backend, x, dx)`.
"""
function hvp end
2929
"""
    hvp!(f, p, backend, x, dx, [extras]) -> p

Compute the Hessian-vector product of `f` at point `x` with seed `dx`, overwriting `p`.

If `extras` is omitted, it is created on the fly by calling `prepare_hvp(f, backend, x, dx)`.
"""
function hvp! end
3636
@@ -45,181 +45,168 @@ abstract type HVPExtras <: Extras end
4545
# Field-less `HVPExtras` subtype: carries no preparation data.
struct NoHVPExtras <: HVPExtras end
4747
# Callable wrapper bundling a function `f` with an AD `backend`.
# Calling an instance on `x` evaluates `gradient(f, backend, x)`.
struct InnerGradient{F,B}
    f::F
    backend::B
end
52+
# Evaluate the gradient of the stored `f` at `x` using the stored `backend`.
function (ig::InnerGradient)(x)
    @compat (; f, backend) = ig
    return gradient(f, backend, x)
end
57+
# Callable wrapper bundling a function `f`, an AD `backend` and a fixed seed `dx`.
# Calling an instance on `x` evaluates `pushforward(f, backend, x, dx)`.
struct InnerPushforwardFixedSeed{F,B,DX}
    f::F
    backend::B
    dx::DX
end
5063
# Evaluate the pushforward of the stored `f` at `x`, along the stored seed `dx`,
# using the stored `backend`.
function (ipfs::InnerPushforwardFixedSeed)(x)
    @compat (; f, backend, dx) = ipfs
    return pushforward(f, backend, x, dx)
end
5768
# Preparation result for the `ForwardOverForward` mode: the inner gradient callable
# and the extras of the outer pushforward applied to it.
struct ForwardOverForwardHVPExtras{IG<:InnerGradient,E<:PushforwardExtras} <: HVPExtras
    inner_gradient::IG
    outer_pushforward_extras::E
end
6273
# Preparation result for the `ForwardOverReverse` mode: the inner gradient callable
# and the extras of the outer pushforward applied to it.
struct ForwardOverReverseHVPExtras{IG<:InnerGradient,E<:PushforwardExtras} <: HVPExtras
    inner_gradient::IG
    outer_pushforward_extras::E
end
6778
# Preparation result for the `ReverseOverForward` mode: only the extras of the outer
# gradient are kept; the inner pushforward callable is rebuilt from `dx` at call time.
struct ReverseOverForwardHVPExtras{E<:GradientExtras} <: HVPExtras
    outer_gradient_extras::E
end
7282
# Preparation result for the `ReverseOverReverse` mode: the inner gradient callable
# and the extras of the outer pullback applied to it.
struct ReverseOverReverseHVPExtras{IG<:InnerGradient,E<:PullbackExtras} <: HVPExtras
    inner_gradient::IG
    outer_pullback_extras::E
end
7787
# Generic backend: wrap it as `SecondOrder(backend, backend)` and delegate to the
# `SecondOrder` method.
function prepare_hvp(f::F, backend::AbstractADType, x, dx) where {F}
    return prepare_hvp(f, SecondOrder(backend, backend), x, dx)
end
8191
# `SecondOrder` backend: pick the mode-specific preparation method by passing the
# value of `hvp_mode(backend)` as an extra dispatch argument.
function prepare_hvp(f::F, backend::SecondOrder, x, dx) where {F}
    return prepare_hvp(f, backend, x, dx, hvp_mode(backend))
end
8595
# Forward-over-forward preparation: prepare the outer pushforward of the inner gradient.
function prepare_hvp(f::F, backend::SecondOrder, x, dx, ::ForwardOverForward) where {F}
    # pushforward of many pushforwards in theory, but pushforward of gradient in practice
    inner_gradient = InnerGradient(f, nested(inner(backend)))
    outer_pushforward_extras = prepare_pushforward(inner_gradient, outer(backend), x, dx)
    return ForwardOverForwardHVPExtras(inner_gradient, outer_pushforward_extras)
end
95102
# Forward-over-reverse preparation: prepare the outer pushforward of the inner gradient.
function prepare_hvp(f::F, backend::SecondOrder, x, dx, ::ForwardOverReverse) where {F}
    # pushforward of gradient
    inner_gradient = InnerGradient(f, nested(inner(backend)))
    outer_pushforward_extras = prepare_pushforward(inner_gradient, outer(backend), x, dx)
    return ForwardOverReverseHVPExtras(inner_gradient, outer_pushforward_extras)
end
105109
# Reverse-over-forward preparation: prepare the outer gradient of the inner pushforward.
function prepare_hvp(f::F, backend::SecondOrder, x, dx, ::ReverseOverForward) where {F}
    # gradient of pushforward
    # uses dx in the closure so it can't be stored
    inner_pushforward = InnerPushforwardFixedSeed(f, nested(inner(backend)), dx)
    outer_gradient_extras = prepare_gradient(inner_pushforward, outer(backend), x)
    return ReverseOverForwardHVPExtras(outer_gradient_extras)
end
117+
# Reverse-over-reverse preparation: prepare the outer pullback of the inner gradient.
function prepare_hvp(f::F, backend::SecondOrder, x, dx, ::ReverseOverReverse) where {F}
    # pullback of the gradient
    inner_gradient = InnerGradient(f, nested(inner(backend)))
    outer_pullback_extras = prepare_pullback(inner_gradient, outer(backend), x, dx)
    return ReverseOverReverseHVPExtras(inner_gradient, outer_pullback_extras)
end
129124
## Preparation (same point)
131126
# Default same-point preparation: the already-prepared `extras` is returned unchanged.
function prepare_hvp_same_point(
    f::F, backend::AbstractADType, x, dx, extras::HVPExtras
) where {F}
    return extras
end
137132
# Without existing extras: run `prepare_hvp` first, then hand the result to the
# five-argument method.
function prepare_hvp_same_point(f::F, backend::AbstractADType, x, dx) where {F}
    extras = prepare_hvp(f, backend, x, dx)
    return prepare_hvp_same_point(f, backend, x, dx, extras)
end
142137
## One argument
144139
# No extras given: prepare them on the fly, then call the prepared method.
function hvp(f::F, backend::AbstractADType, x, dx) where {F}
    return hvp(f, backend, x, dx, prepare_hvp(f, backend, x, dx))
end
148143
# No extras given: prepare them on the fly, then call the prepared in-place method.
function hvp!(f::F, p, backend::AbstractADType, x, dx) where {F}
    return hvp!(f, p, backend, x, dx, prepare_hvp(f, backend, x, dx))
end
152147
# Generic backend with extras: wrap it as `SecondOrder(backend, backend)` and delegate.
function hvp(f::F, backend::AbstractADType, x, dx, extras::HVPExtras) where {F}
    return hvp(f, SecondOrder(backend, backend), x, dx, extras)
end
156151
# Forward-over-forward: outer pushforward of the stored inner gradient along `dx`.
function hvp(
    f::F, backend::SecondOrder, x, dx, extras::ForwardOverForwardHVPExtras
) where {F}
    @compat (; inner_gradient, outer_pushforward_extras) = extras
    return pushforward(inner_gradient, outer(backend), x, dx, outer_pushforward_extras)
end
165158
# Forward-over-reverse: outer pushforward of the stored inner gradient along `dx`.
function hvp(
    f::F, backend::SecondOrder, x, dx, extras::ForwardOverReverseHVPExtras
) where {F}
    @compat (; inner_gradient, outer_pushforward_extras) = extras
    return pushforward(inner_gradient, outer(backend), x, dx, outer_pushforward_extras)
end
174165
# Reverse-over-forward: outer gradient of the inner pushforward with seed `dx`.
function hvp(
    f::F, backend::SecondOrder, x, dx, extras::ReverseOverForwardHVPExtras
) where {F}
    @compat (; outer_gradient_extras) = extras
    # The callable depends on `dx`, so it is rebuilt on every call instead of being
    # read from `extras`.
    inner_pushforward = InnerPushforwardFixedSeed(f, nested(inner(backend)), dx)
    return gradient(inner_pushforward, outer(backend), x, outer_gradient_extras)
end
182173
# Reverse-over-reverse: outer pullback of the stored inner gradient with seed `dx`.
function hvp(
    f::F, backend::SecondOrder, x, dx, extras::ReverseOverReverseHVPExtras
) where {F}
    @compat (; inner_gradient, outer_pullback_extras) = extras
    return pullback(inner_gradient, outer(backend), x, dx, outer_pullback_extras)
end
189180
# Generic backend with extras: wrap it as `SecondOrder(backend, backend)` and delegate.
function hvp!(f::F, p, backend::AbstractADType, x, dx, extras::HVPExtras) where {F}
    return hvp!(f, p, SecondOrder(backend, backend), x, dx, extras)
end
193184
# Forward-over-forward, in place: outer `pushforward!` of the stored inner gradient,
# with `p` passed as the output argument.
function hvp!(
    f::F, p, backend::SecondOrder, x, dx, extras::ForwardOverForwardHVPExtras
) where {F}
    @compat (; inner_gradient, outer_pushforward_extras) = extras
    return pushforward!(inner_gradient, p, outer(backend), x, dx, outer_pushforward_extras)
end
202191
# Forward-over-reverse, in place: outer `pushforward!` of the stored inner gradient,
# with `p` passed as the output argument.
function hvp!(
    f::F, p, backend::SecondOrder, x, dx, extras::ForwardOverReverseHVPExtras
) where {F}
    @compat (; inner_gradient, outer_pushforward_extras) = extras
    return pushforward!(inner_gradient, p, outer(backend), x, dx, outer_pushforward_extras)
end
211198
# Reverse-over-forward, in place: outer `gradient!` of the inner pushforward with
# seed `dx`, with `p` passed as the output argument.
function hvp!(
    f::F, p, backend::SecondOrder, x, dx, extras::ReverseOverForwardHVPExtras
) where {F}
    @compat (; outer_gradient_extras) = extras
    # The callable depends on `dx`, so it is rebuilt on every call instead of being
    # read from `extras`.
    inner_pushforward = InnerPushforwardFixedSeed(f, nested(inner(backend)), dx)
    return gradient!(inner_pushforward, p, outer(backend), x, outer_gradient_extras)
end
219206
# Reverse-over-reverse, in place: outer `pullback!` of the stored inner gradient with
# seed `dx`, with `p` passed as the output argument.
function hvp!(
    f::F, p, backend::SecondOrder, x, dx, extras::ReverseOverReverseHVPExtras
) where {F}
    @compat (; inner_gradient, outer_pullback_extras) = extras
    return pullback!(inner_gradient, p, outer(backend), x, dx, outer_pullback_extras)
end
0 commit comments