11# # Pushforward
22
3- struct FastDifferentiationOneArgPushforwardExtras{Y,E1,E2 } <: PushforwardExtras
3+ struct FastDifferentiationOneArgPushforwardExtras{Y,E1,E1! } <: PushforwardExtras
44 y_prototype:: Y
55 jvp_exe:: E1
6- jvp_exe!:: E2
6+ jvp_exe!:: E1!
77end
88
99function DI. prepare_pushforward (f, :: AutoFastDifferentiation , x, dx)
7070
7171# # Pullback
7272
73- struct FastDifferentiationOneArgPullbackExtras{E1,E2 } <: PullbackExtras
73+ struct FastDifferentiationOneArgPullbackExtras{E1,E1! } <: PullbackExtras
7474 vjp_exe:: E1
75- vjp_exe!:: E2
75+ vjp_exe!:: E1!
7676end
7777
7878function DI. prepare_pullback (f, :: AutoFastDifferentiation , x, dy)
@@ -133,10 +133,10 @@ end
133133
134134# # Derivative
135135
136- struct FastDifferentiationOneArgDerivativeExtras{Y,E1,E2 } <: DerivativeExtras
136+ struct FastDifferentiationOneArgDerivativeExtras{Y,E1,E1! } <: DerivativeExtras
137137 y_prototype:: Y
138138 der_exe:: E1
139- der_exe!:: E2
139+ der_exe!:: E1!
140140end
141141
142142function DI. prepare_derivative (f, :: AutoFastDifferentiation , x)
@@ -190,13 +190,12 @@ end
190190
191191# # Gradient
192192
193- struct FastDifferentiationOneArgGradientExtras{E1,E2 } <: GradientExtras
193+ struct FastDifferentiationOneArgGradientExtras{E1,E1! } <: GradientExtras
194194 jac_exe:: E1
195- jac_exe!:: E2
195+ jac_exe!:: E1!
196196end
197197
198198function DI. prepare_gradient (f, backend:: AutoFastDifferentiation , x)
199- y_prototype = f (x)
200199 x_var = make_variables (:x , size (x)... )
201200 y_var = f (x_var)
202201
@@ -241,10 +240,10 @@ end
241240
242241# # Jacobian
243242
244- struct FastDifferentiationOneArgJacobianExtras{Y,E1,E2 } <: JacobianExtras
243+ struct FastDifferentiationOneArgJacobianExtras{Y,E1,E1! } <: JacobianExtras
245244 y_prototype:: Y
246245 jac_exe:: E1
247- jac_exe!:: E2
246+ jac_exe!:: E1!
248247end
249248
250249function DI. prepare_jacobian (
@@ -307,34 +306,29 @@ end
307306
308307# # Second derivative
309308
310- struct FastDifferentiationAllocatingSecondDerivativeExtras{Y,E1,E1! ,E2,E2!} < :
309+ struct FastDifferentiationAllocatingSecondDerivativeExtras{Y,D ,E2,E2!} < :
311310 SecondDerivativeExtras
312311 y_prototype:: Y
313- der_exe:: E1
314- der_exe!:: E1!
312+ derivative_extras:: D
315313 der2_exe:: E2
316314 der2_exe!:: E2!
317315end
318316
319- function DI. prepare_second_derivative (f, :: AutoFastDifferentiation , x)
317+ function DI. prepare_second_derivative (f, backend :: AutoFastDifferentiation , x)
320318 y_prototype = f (x)
321319 x_var = only (make_variables (:x ))
322320 y_var = f (x_var)
323321
324322 x_vec_var = monovec (x_var)
325323 y_vec_var = y_var isa Number ? monovec (y_var) : vec (y_var)
326324
327- der_vec_var = derivative (y_vec_var, x_var)
328325 der2_vec_var = derivative (y_vec_var, x_var, x_var)
329-
330- der_exe = make_function (der_vec_var, x_vec_var; in_place= false )
331- der_exe! = make_function (der_vec_var, x_vec_var; in_place= true )
332-
333326 der2_exe = make_function (der2_vec_var, x_vec_var; in_place= false )
334327 der2_exe! = make_function (der2_vec_var, x_vec_var; in_place= true )
335328
329+ derivative_extras = DI. prepare_derivative (f, backend, x)
336330 return FastDifferentiationAllocatingSecondDerivativeExtras (
337- y_prototype, der_exe, der_exe! , der2_exe, der2_exe!
331+ y_prototype, derivative_extras , der2_exe, der2_exe!
338332 )
339333end
340334
@@ -364,20 +358,13 @@ end
364358
365359function DI. value_derivative_and_second_derivative (
366360 f,
367- :: AutoFastDifferentiation ,
361+ backend :: AutoFastDifferentiation ,
368362 x,
369363 extras:: FastDifferentiationAllocatingSecondDerivativeExtras ,
370364)
371- y = f (x)
372- if extras. y_prototype isa Number
373- der = only (extras. der_exe (monovec (x)))
374- der2 = only (extras. der2_exe (monovec (x)))
375- return y, der, der2
376- else
377- der = reshape (extras. der_exe (monovec (x)), size (extras. y_prototype))
378- der2 = reshape (extras. der2_exe (monovec (x)), size (extras. y_prototype))
379- return y, der, der2
380- end
365+ y, der = DI. value_and_derivative (f, backend, x, extras. derivative_extras)
366+ der2 = DI. second_derivative (f, backend, x, extras)
367+ return y, der, der2
381368end
382369
383370function DI. value_derivative_and_second_derivative! (
@@ -388,17 +375,16 @@ function DI.value_derivative_and_second_derivative!(
388375 x,
389376 extras:: FastDifferentiationAllocatingSecondDerivativeExtras ,
390377)
391- y = f (x)
392- extras. der_exe! (vec (der), monovec (x))
393- extras. der2_exe! (vec (der2), monovec (x))
378+ y, _ = DI. value_and_derivative! (f, der, backend, x, extras. derivative_extras)
379+ DI. second_derivative! (f, der2, backend, x, extras)
394380 return y, der, der2
395381end
396382
397383# # HVP
398384
399- struct FastDifferentiationHVPExtras{E1 ,E2} <: HVPExtras
400- hvp_exe:: E1
401- hvp_exe!:: E2
385+ struct FastDifferentiationHVPExtras{E2 ,E2! } <: HVPExtras
386+ hvp_exe:: E2
387+ hvp_exe!:: E2!
402388end
403389
404390function DI. prepare_hvp (f, :: AutoFastDifferentiation , x, v)
@@ -428,24 +414,30 @@ end
428414
429415# # Hessian
430416
431- struct FastDifferentiationHessianExtras{E1,E2} <: HessianExtras
432- hess_exe:: E1
433- hess_exe!:: E2
417+ struct FastDifferentiationHessianExtras{G,E2,E2!} <: HessianExtras
418+ gradient_extras:: G
419+ hess_exe:: E2
420+ hess_exe!:: E2!
434421end
435422
436423function DI. prepare_hessian (
437424 f, backend:: Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}} , x
438425)
439- x_vec_var = make_variables (:x , size (x)... )
440- y_vec_var = f (x_vec_var)
426+ x_var = make_variables (:x , size (x)... )
427+ y_var = f (x_var)
428+
429+ x_vec_var = vec (x_var)
430+
441431 hess_var = if backend isa AutoSparse
442- sparse_hessian (y_vec_var, vec ( x_vec_var) )
432+ sparse_hessian (y_var, x_vec_var)
443433 else
444- hessian (y_vec_var, vec ( x_vec_var) )
434+ hessian (y_var, x_vec_var)
445435 end
446- hess_exe = make_function (hess_var, vec (x_vec_var); in_place= false )
447- hess_exe! = make_function (hess_var, vec (x_vec_var); in_place= true )
448- return FastDifferentiationHessianExtras (hess_exe, hess_exe!)
436+ hess_exe = make_function (hess_var, x_vec_var; in_place= false )
437+ hess_exe! = make_function (hess_var, x_vec_var; in_place= true )
438+
439+ gradient_extras = DI. prepare_gradient (f, maybe_dense_ad (backend), x)
440+ return FastDifferentiationHessianExtras (gradient_extras, hess_exe, hess_exe!)
449441end
450442
451443function DI. hessian (
@@ -467,3 +459,29 @@ function DI.hessian!(
467459 extras. hess_exe! (hess, vec (x))
468460 return hess
469461end
462+
463+ function DI. value_gradient_and_hessian (
464+ f,
465+ backend:: Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}} ,
466+ x,
467+ extras:: FastDifferentiationHessianExtras ,
468+ )
469+ y, grad = DI. value_and_gradient (f, maybe_dense_ad (backend), x, extras. gradient_extras)
470+ hess = DI. hessian (f, backend, x, extras)
471+ return y, grad, hess
472+ end
473+
474+ function DI. value_gradient_and_hessian! (
475+ f,
476+ grad,
477+ hess,
478+ backend:: Union{AutoFastDifferentiation,AutoSparse{<:AutoFastDifferentiation}} ,
479+ x,
480+ extras:: FastDifferentiationHessianExtras ,
481+ )
482+ y, _ = DI. value_and_gradient! (
483+ f, grad, maybe_dense_ad (backend), x, extras. gradient_extras
484+ )
485+ DI. hessian! (f, hess, backend, x, extras)
486+ return y, grad, hess
487+ end
0 commit comments