@@ -66,21 +66,8 @@ function value_and_gradient(ab::AbstractBackend, f, xs...)
6666 return value, reshape .(adjoint .(jacs),size .(xs))
6767end
6868function value_and_jacobian (ab:: AbstractBackend , f, xs... )
69- local value
70- primalcalled = false
71- if lowest (ab) isa AbstractFiniteDifference
72- value = primal_value (ab, nothing , f, xs)
73- primalcalled = true
74- end
75- jacs = jacobian (lowest (ab), (_xs... ,) -> begin
76- v = f (_xs... )
77- if ! primalcalled
78- value = primal_value (ab, v, f, xs)
79- primalcalled = true
80- end
81- return v
82- end , xs... )
83-
69+ value = f (xs... )
70+ jacs = jacobian (lowest (ab), f, xs... )
8471 return value, jacs
8572end
8673function value_and_hessian (ab:: AbstractBackend , f, x)
@@ -89,71 +76,54 @@ function value_and_hessian(ab::AbstractBackend, f, x)
8976 x = only (x)
9077 end
9178
92- local value
93- primalcalled = false
94- if ab isa AbstractFiniteDifference
95- value = primal_value (ab, nothing , f, (x,))
96- primalcalled = true
97- end
79+ value = f (x)
9880 hess = jacobian (second_lowest (ab), _x -> begin
99- v, g = value_and_gradient (lowest (ab), f, _x)
100- if ! primalcalled
101- value = primal_value (ab, v, f, (x,))
102- primalcalled = true
103- end
81+ g = gradient (lowest (ab), f, _x)
10482 return g[1 ] # gradient returns a tuple
10583 end , x)
84+
10685 return value, hess
10786end
10887function value_and_hessian (ab:: HigherOrderBackend , f, x)
10988 if x isa Tuple
11089 # only support computation of Hessian for functions with single input argument
11190 x = only (x)
11291 end
113- local value
114- primalcalled = false
92+
93+ value = f (x)
11594 hess = jacobian (second_lowest (ab), (_x,) -> begin
116- v, g = value_and_gradient (lowest (ab), f, _x)
117- if ! primalcalled
118- value = primal_value (ab, v, f, (x,))
119- primalcalled = true
120- end
95+ g = gradient (lowest (ab), f, _x)
12196 return g[1 ] # gradient returns a tuple
12297 end , x)
98+
12399 return value, hess
124100end
125101function value_gradient_and_hessian (ab:: AbstractBackend , f, x)
126102 if x isa Tuple
127103 # only support computation of Hessian for functions with single input argument
128104 x = only (x)
129105 end
130- local value
131- primalcalled = false
106+
107+ value = f (x)
132108 grads, hess = value_and_jacobian (second_lowest (ab), _x -> begin
133- v, g = value_and_gradient (lowest (ab), f, _x)
134- if ! primalcalled
135- value = primal_value (second_lowest (ab), v, f, (x,))
136- primalcalled = true
137- end
109+ g = gradient (lowest (ab), f, _x)
138110 return g[1 ] # gradient returns a tuple
139111 end , x)
112+
140113 return value, (grads,), hess
141114end
142115function value_gradient_and_hessian (ab:: HigherOrderBackend , f, x)
143116 if x isa Tuple
144117 # only support computation of Hessian for functions with single input argument
145118 x = only (x)
146119 end
147- local value
148- primalcalled = false
120+
121+ value = f (x)
149122 grads, hess = value_and_jacobian (second_lowest (ab), _x -> begin
150- v, g = value_and_gradient (lowest (ab), f, _x)
151- if ! primalcalled
152- value = primal_value (second_lowest (ab), v, f, (x,))
153- primalcalled = true
154- end
123+ g = gradient (lowest (ab), f, _x)
155124 return g[1 ] # gradient returns a tuple
156125 end , x)
126+
157127 return value, (grads,), hess
158128end
159129
@@ -180,26 +150,16 @@ function value_and_pushforward_function(
180150 f,
181151 xs... ,
182152)
183- return (ds) -> begin
153+ n = length (xs)
154+ value = f (xs... )
155+ pf_function = pushforward_function (lowest (ab), f, xs... )
156+
157+ return ds -> begin
184158 if ! (ds isa Tuple)
185159 ds = (ds,)
186160 end
187- @assert length (ds) == length (xs)
188- local value
189- primalcalled = false
190- if ab isa AbstractFiniteDifference
191- value = primal_value (ab, nothing , f, xs)
192- primalcalled = true
193- end
194- pf = pushforward_function (lowest (ab), (_xs... ,) -> begin
195- vs = f (_xs... )
196- if ! primalcalled
197- value = primal_value (lowest (ab), vs, f, xs)
198- primalcalled = true
199- end
200- return vs
201- end , xs... )(ds)
202-
161+ @assert length (ds) == n
162+ pf = pf_function (ds)
203163 return value, pf
204164 end
205165end
@@ -476,12 +436,6 @@ macro primitive(expr)
476436 return define_pushforward_function_and_friends (fdef) |> esc
477437 elseif name == :value_and_pullback_function
478438 return define_value_and_pullback_function_and_friends (fdef) |> esc
479- elseif name == :jacobian
480- return define_jacobian_and_friends (fdef) |> esc
481- elseif name == :primal_value
482- return define_primal_value (fdef) |> esc
483- elseif name == :pullback_function
484- return define_pullback_function_and_friends (fdef) |> esc
485439 else
486440 throw (" Unsupported AD primitive." )
487441 end
0 commit comments