Skip to content

Commit 8608f0b

Browse files
committed
fix *, optimize vector case
1 parent 0740903 commit 8608f0b

File tree

1 file changed

+36
-22
lines changed

1 file changed

+36
-22
lines changed

src/dual.jl

Lines changed: 36 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -768,53 +768,67 @@ end
768768
# Efficient left multiplication/division of #
769769
# Dual array by a constant matrix #
770770
#-------------------------------------------#
771-
772-
# creates the copy of x and applies fvalue!(values) to its values, and fpartial!(partial, ix) to its partials
773-
function _map_dual_components(fvalue!, fpartial!, x::AbstractArray{DT}) where DT <: Dual{<:Any, T} where T
771+
# creates the copy of x and applies fvalue!(values(y), values(x)) to its values,
772+
# and fpartial!(partial(y, i), partial(y, i), i) to its partials
773+
function _map_dual_components!(fvalue!, fpartial!, y::AbstractArray{DT}, x::AbstractArray{DT}) where DT <: Dual{<:Any, T} where T
774774
N = npartials(DT)
775-
res = similar(x) # result
776-
t = similar(x, T) # temporary Array{T} for fvalue!/fpartial! application
775+
tx = similar(x, T)
776+
ty = similar(y, T) # temporary Array{T} for fvalue!/fpartial! application
777777
# y allows res to be accessed as Array{T}
778-
y = reinterpret(reshape, T, res)
779-
@assert size(y) == (N + 1, length(res))
778+
yarr = reinterpret(reshape, T, y)
779+
@assert size(yarr) == (N + 1, size(y)...)
780780
ystride = size(y, 1)
781781

782782
# calculate res values
783783
@inbounds for (j, v) in enumerate(x)
784-
t[j] = value(v)
784+
tx[j] = value(v)
785785
end
786-
fvalue!(t)
786+
fvalue!(ty, tx)
787787
k = 1
788-
@inbounds for tt in t
789-
y[k] = tt
788+
@inbounds for tt in ty
789+
yarr[k] = tt
790790
k += ystride
791791
end
792792

793793
# calculate each res partial
794794
for i in 1:N
795795
@inbounds for (j, v) in enumerate(x)
796-
t[j] = partials(v, i)
796+
tx[j] = partials(v, i)
797797
end
798-
fpartial!(t, i)
798+
fpartial!(ty, tx, i)
799799
k = i + 1
800-
@inbounds for tt in t
801-
y[k] = tt
800+
@inbounds for tt in ty
801+
yarr[k] = tt
802802
k += ystride
803803
end
804804
end
805805

806-
return res
806+
return y
807807
end
808808

809809
for MT in (StridedMatrix{<:LinearAlgebra.BlasFloat},
810810
LowerTriangular{<:LinearAlgebra.BlasFloat},
811-
UpperTriangular{<:LinearAlgebra.BlasFloat}),
812-
XT in (StridedMatrix{<:Dual}, StridedVector{<:Dual})
813-
@eval Base.:\(m::$MT, x::$XT) =
814-
_map_dual_components(Base.Fix1(ldiv!, m), (x, _) -> ldiv!(m, x), x)
811+
UpperTriangular{<:LinearAlgebra.BlasFloat})
812+
813+
@eval function Base.:\(m::$MT, x::StridedVector{<:Dual})
814+
T = valtype(eltype(x))
815+
ldiv!(m', reinterpret(reshape, T, res))
816+
return res
817+
end
818+
819+
@eval Base.:\(m::$MT, x::StridedMatrix{<:Dual}) =
820+
_map_dual_components!((x, _) -> ldiv!(m, x), (x, _, _) -> ldiv!(m, x), similar(x), x)
821+
822+
@eval function Base.:*(m::$MT, x::StridedVector{<:Dual})
823+
T = valtype(eltype(x))
824+
res = similar(x, (size(m, 1),))
825+
mul!(reinterpret(reshape, T, res), reinterpret(reshape, T, x), m')
826+
return res
827+
end
815828

816-
@eval Base.:*(m::$MT, x::$XT) =
817-
_map_dual_components(Base.Fix1(lmul!, m), (x, _) -> lmul!(m, x), x)
829+
@eval Base.:*(m::$MT, x::StridedMatrix{<:Dual}) =
830+
_map_dual_components!((y, x) -> mul!(y, m, x), (y, x, _) -> mul!(y, m, x),
831+
similar(x, (size(m, 1), size(x, 2))), x)
818832
end
819833

820834
###################

0 commit comments

Comments
 (0)