Skip to content

Commit de02f7e

Browse files
Merge pull request #7 from tshort/master
[for discussion] an abstract symbolic type and chain rule code
2 parents 8e213b8 + d0cccbb commit de02f7e

File tree

4 files changed

+400
-337
lines changed

4 files changed

+400
-337
lines changed

src/Calculus.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,5 +57,6 @@ module Calculus
5757
include("check_derivative.jl")
5858
include("integrate.jl")
5959
include("symbolic.jl")
60+
include("differentiate.jl")
6061
include("deparse.jl")
6162
end

src/differentiate.jl

Lines changed: 253 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,253 @@
1+
2+
export differentiate
3+
4+
#################################################################
5+
#
6+
# differentiate()
7+
# based on John's differentiate and this code, I think by Miles Lubin:
8+
# https://github.com/IainNZ/NLTester/blob/master/julia/nlp.jl#L74
9+
#
10+
#################################################################
11+
12+
differentiate(ex::SymbolicVariable, wrt::SymbolicVariable) = (ex == wrt) ? 1 : 0
13+
14+
differentiate(ex::Number, wrt::SymbolicVariable) = 0
15+
16+
function differentiate(ex::Expr,wrt)
17+
if ex.head != :call
18+
error("Unrecognized expression $ex")
19+
end
20+
simplify(differentiate(SymbolParameter(ex.args[1]), ex.args[2:end], wrt))
21+
end
22+
23+
differentiate{T}(x::SymbolParameter{T}, args, wrt) = error("Derivative of function " * string(T) * " not supported")
24+
25+
# The Power Rule:
26+
function differentiate(::SymbolParameter{:^}, args, wrt)
27+
x = args[1]
28+
y = args[2]
29+
xp = differentiate(x, wrt)
30+
yp = differentiate(y, wrt)
31+
if xp == 0 && yp == 0
32+
return 0
33+
elseif xp != 0 && yp == 0
34+
return :( $y * $xp * ($x ^ ($y - 1)) )
35+
else
36+
return :( $x ^ $y * ($xp * $y / $x + $yp * log($x)) )
37+
end
38+
end
39+
40+
function differentiate(::SymbolParameter{:+}, args, wrt)
41+
termdiffs = {:+}
42+
for y in args
43+
x = differentiate(y, wrt)
44+
if x != 0
45+
push!(termdiffs, x)
46+
end
47+
end
48+
if (length(termdiffs) == 1)
49+
return 0
50+
elseif (length(termdiffs) == 2)
51+
return termdiffs[2]
52+
else
53+
return Expr(:call, termdiffs...)
54+
end
55+
end
56+
57+
function differentiate(::SymbolParameter{:-}, args, wrt)
58+
termdiffs = {:-}
59+
# first term is special, can't be dropped
60+
term1 = differentiate(args[1], wrt)
61+
push!(termdiffs, term1)
62+
for y in args[2:end]
63+
x = differentiate(y, wrt)
64+
if x != 0
65+
push!(termdiffs, x)
66+
end
67+
end
68+
if term1 != 0 && length(termdiffs) == 2 && length(args) >= 2
69+
# if all of the terms but the first disappeared, we just return the first
70+
return term1
71+
elseif (term1 == 0 && length(termdiffs) == 2)
72+
return 0
73+
else
74+
return Expr(:call, termdiffs...)
75+
end
76+
end
77+
78+
# The Product Rule
79+
# d/dx (f * g) = (d/dx f) * g + f * (d/dx g)
80+
# d/dx (f * g * h) = (d/dx f) * g * h + f * (d/dx g) * h + ...
81+
function differentiate(::SymbolParameter{:*}, args, wrt)
82+
n = length(args)
83+
res_args = Array(Any, n)
84+
for i in 1:n
85+
new_args = Array(Any, n)
86+
for j in 1:n
87+
if j == i
88+
new_args[j] = differentiate(args[j], wrt)
89+
else
90+
new_args[j] = args[j]
91+
end
92+
end
93+
res_args[i] = Expr(:call, :*, new_args...)
94+
end
95+
return Expr(:call, :+, res_args...)
96+
end
97+
98+
# The Quotient Rule
99+
# d/dx (f / g) = ((d/dx f) * g - f * (d/dx g)) / g^2
100+
function differentiate(::SymbolParameter{:/}, args, wrt)
101+
x = args[1]
102+
y = args[2]
103+
xp = differentiate(x, wrt)
104+
yp = differentiate(y, wrt)
105+
if xp == 0 && yp == 0
106+
return 0
107+
elseif xp == 0
108+
return :( -$yp * $x )
109+
elseif yp == 0
110+
return :( $xp * $y )
111+
else
112+
return :( ($xp * $y - $x * $yp) / $y^2 )
113+
end
114+
end
115+
116+
117+
derivative_rules = [
118+
( :sqrt, :( xp / 2 / sqrt(x) ))
119+
( :cbrt, :( xp / 3 / cbrt(x)^2 ))
120+
( :square, :( xp * 2 * x ))
121+
( :log, :( xp / x ))
122+
( :log10, :( xp / x / log(10) ))
123+
( :log2, :( xp / x / log(2) ))
124+
( :log1p, :( xp / (x + 1) ))
125+
( :exp, :( xp * exp(x) ))
126+
( :exp2, :( xp * log(2) * exp2(x) ))
127+
( :expm1, :( xp * exp(x) ))
128+
( :sin, :( xp * cos(x) ))
129+
( :cos, :( -xp * sin(x) ))
130+
( :tan, :( xp * (1 + tan(x)^2) ))
131+
( :sec, :( xp * sec(x) * tan(x) ))
132+
( :csc, :( -xp * csc(x) * cot(x) ))
133+
( :cot, :( -xp * (1 + cot(x)^2) ))
134+
( :sind, :( xp * cosd(x) ))
135+
( :cosd, :( -xp * sind(x) ))
136+
( :tand, :( xp * (1 + tand(x)^2) ))
137+
( :secd, :( xp * secd(x) * tand(x) ))
138+
( :cscd, :( -xp * cscd(x) * cotd(x) ))
139+
( :cotd, :( -xp * (1 + cotd(x)^2) ))
140+
( :asin, :( xp / sqrt(1 - x^2) ))
141+
( :acos, :( -xp / sqrt(1 - x^2) ))
142+
( :atan, :( xp / (1 + x^2) ))
143+
( :asec, :( xp / abs(x) / sqrt(x^2 - 1) ))
144+
( :acsc, :( -xp / abs(x) / sqrt(x^2 - 1) ))
145+
( :acot, :( -xp / (1 + x^2) ))
146+
( :asind, :( xp * 180 / pi / sqrt(1 - x^2) ))
147+
( :acosd, :( -xp * 180 / pi / sqrt(1 - x^2) ))
148+
( :atand, :( xp * 180 / pi / (1 + x^2) ))
149+
( :asecd, :( xp * 180 / pi / abs(x) / sqrt(x^2 - 1) ))
150+
( :acscd, :( -xp * 180 / pi / abs(x) / sqrt(x^2 - 1) ))
151+
( :acotd, :( -xp * 180 / pi / (1 + x^2) ))
152+
( :sinh, :( xp * cosh(x) ))
153+
( :cosh, :( xp * sinh(x) ))
154+
( :tanh, :( xp * sech(x)^2 ))
155+
( :sech, :( -xp * tanh(x) * sech(x) ))
156+
( :csch, :( -xp * coth(x) * csch(x) ))
157+
( :coth, :( -xp * csch(x)^2 ))
158+
( :asinh, :( xp / sqrt(x^2 + 1) ))
159+
( :acosh, :( xp / sqrt(x^2 - 1) ))
160+
( :atanh, :( xp / (1 - x^2) ))
161+
( :asech, :( -xp / x / sqrt(1 - x^2) ))
162+
( :acsch, :( -xp / abs(x) / sqrt(1 + x^2) ))
163+
( :acoth, :( xp / (1 - x^2) ))
164+
( :erf, :( xp * 2 * exp(-square(x)) / sqrt(pi) ))
165+
( :erfc, :( -xp * 2 * exp(-square(x)) / sqrt(pi) ))
166+
( :erfi, :( xp * 2 * exp(square(x)) / sqrt(pi) ))
167+
( :gamma, :( xp * digamma(x) * gamma(x) ))
168+
( :lgamma, :( xp * digamma(x) ))
169+
( :airy, :( xp * airyprime(x) )) # note: only covers the 1-arg version
170+
( :airyprime, :( xp * airy(2, x) ))
171+
( :airyai, :( xp * airyaiprime(x) ))
172+
( :airybi, :( xp * airybiprime(x) ))
173+
( :airyaiprime, :( xp * x * airyai(x) ))
174+
( :airybiprime, :( xp * x * airybi(x) ))
175+
( :besselj0, :( -xp * besselj1(x) ))
176+
( :besselj1, :( xp * (besselj0(x) - besselj(2, x)) / 2 ))
177+
( :bessely0, :( -xp * bessely1(x) ))
178+
( :bessely1, :( xp * (bessely0(x) - bessely(2, x)) / 2 ))
179+
## ( :erfcx, :( xp * (2 * x * erfcx(x) - 2 / sqrt(pi)) )) # uncertain
180+
## ( :dawson, :( xp * (1 - 2x * dawson(x)) )) # uncertain
181+
182+
]
183+
184+
for (funsym, exp) in derivative_rules
185+
@eval function differentiate(::SymbolParameter{$(Meta.quot(funsym))}, args, wrt)
186+
x = args[1]
187+
xp = differentiate(x, wrt)
188+
if xp != 0
189+
return @sexpr($exp)
190+
else
191+
return 0
192+
end
193+
end
194+
end
195+
196+
derivative_rules_bessel = [
197+
( :besselj, :( xp * (besselj(nu - 1, x) - besselj(nu + 1, x)) / 2 ))
198+
( :besseli, :( xp * (besseli(nu - 1, x) + besseli(nu + 1, x)) / 2 ))
199+
( :bessely, :( xp * (bessely(nu - 1, x) - bessely(nu + 1, x)) / 2 ))
200+
( :besselk, :( -xp * (besselk(nu - 1, x) + besselk(nu + 1, x)) / 2 ))
201+
( :hankelh1, :( xp * (hankelh1(nu - 1, x) - hankelh1(nu + 1, x)) / 2 ))
202+
( :hankelh2, :( xp * (hankelh2(nu - 1, x) - hankelh2(nu + 1, x)) / 2 ))
203+
]
204+
205+
# 2-argument bessel functions
206+
for (funsym, exp) in derivative_rules_bessel
207+
@eval function differentiate(::SymbolParameter{$(Meta.quot(funsym))}, args, wrt)
208+
nu = args[1]
209+
x = args[2]
210+
xp = differentiate(x, wrt)
211+
if xp != 0
212+
return @sexpr($exp)
213+
else
214+
return 0
215+
end
216+
end
217+
end
218+
219+
### Other functions from julia/base/math.jl we might want to define
220+
### derivatives for. Some have two arguments.
221+
222+
## atan2
223+
## hypot
224+
## beta, lbeta, eta, zeta, digamma
225+
226+
function differentiate(ex::Expr, targets::Vector{Symbol})
227+
n = length(targets)
228+
exprs = Array(Expr, n)
229+
for i in 1:n
230+
exprs[i] = differentiate(ex, targets[i])
231+
end
232+
return exprs
233+
end
234+
235+
236+
differentiate(ex::Expr) = differentiate(ex, :x)
237+
238+
function differentiate(s::String, target::Symbol)
239+
differentiate(parse(s), target)
240+
end
241+
function differentiate(s::String, targets::Vector{Symbol})
242+
differentiate(parse(s), targets)
243+
end
244+
function differentiate(s::String, target::String)
245+
differentiate(parse(s), symbol(target))
246+
end
247+
function differentiate{T <: String}(s::String, targets::Vector{T})
248+
differentiate(parse(s), map(target -> symbol(target), targets))
249+
end
250+
function differentiate(s::String)
251+
differentiate(parse(s), :x)
252+
end
253+

0 commit comments

Comments
 (0)