Skip to content

Commit f7b5556

Browse files
authored
fix jacobian in case where target functions's returned value is aliased with input (#106)
1 parent 2ec7168 commit f7b5556

File tree

2 files changed

+18
-12
lines changed

2 files changed

+18
-12
lines changed

src/finite_difference.jl

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -166,24 +166,23 @@ function finite_difference_jacobian!{R <: Number,
166166

167167
# Iterate over each dimension of the gradient separately.
168168
if dtype == :forward
169+
shifted_x = copy(x)
169170
for i = 1:n
170171
@forwardrule x[i] epsilon
171-
oldx = x[i]
172-
x[i] = oldx + epsilon
173-
f_xplusdx = f(x)
174-
x[i] = oldx
175-
J[:, i] = (f_xplusdx - f_x) / epsilon
172+
shifted_x[i] += epsilon
173+
J[:, i] = (f(shifted_x) - f_x) / epsilon
174+
shifted_x[i] = x[i]
176175
end
177176
elseif dtype == :central
177+
shifted_x_plus = copy(x)
178+
shifted_x_minus = copy(x)
178179
for i = 1:n
179180
@centralrule x[i] epsilon
180-
oldx = x[i]
181-
x[i] = oldx + epsilon
182-
f_xplusdx = f(x)
183-
x[i] = oldx - epsilon
184-
f_xminusdx = f(x)
185-
x[i] = oldx
186-
J[:, i] = (f_xplusdx - f_xminusdx) / (epsilon + epsilon)
181+
shifted_x_plus[i] += epsilon
182+
shifted_x_minus[i] -= epsilon
183+
J[:, i] = (f(shifted_x_plus) - f(shifted_x_minus)) / (epsilon + epsilon)
184+
shifted_x_plus[i] = x[i]
185+
shifted_x_minus[i] = x[i]
187186
end
188187
else
189188
error("dtype must :forward or :central")

test/derivative.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,13 @@ f4(x::Vector) = (100.0 - x[1])^2 + (50.0 - x[2])^2
3131
@test norm(Calculus.gradient(f4, :central)([100.0, 50.0]) - [0.0, 0.0]) < 10e-4
3232
@test norm(Calculus.gradient(f4)([100.0, 50.0]) - [0.0, 0.0]) < 10e-4
3333

34+
#
35+
# jacobian()
36+
#
37+
38+
@test norm(Calculus.jacobian(identity, rand(3), :forward) - eye(3)) < 10e-4
39+
@test norm(Calculus.jacobian(identity, rand(3), :central) - eye(3)) < 10e-4
40+
3441
#
3542
# second_derivative()
3643
#

0 commit comments

Comments
 (0)