Skip to content

Commit ad07f60

Browse files
committed
fix *, optimize vector case
1 parent e213856 commit ad07f60

File tree

1 file changed

+36
-22
lines changed

1 file changed

+36
-22
lines changed

src/dual.jl

Lines changed: 36 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -776,53 +776,67 @@ end
776776
# Efficient left multiplication/division of #
777777
# Dual array by a constant matrix #
778778
#-------------------------------------------#
779-
780-
# creates the copy of x and applies fvalue!(values) to its values, and fpartial!(partial, ix) to its partials
781-
function _map_dual_components(fvalue!, fpartial!, x::AbstractArray{DT}) where DT <: Dual{<:Any, T} where T
779+
# creates the copy of x and applies fvalue!(values(y), values(x)) to its values,
780+
# and fpartial!(partial(y, i), partial(y, i), i) to its partials
781+
function _map_dual_components!(fvalue!, fpartial!, y::AbstractArray{DT}, x::AbstractArray{DT}) where DT <: Dual{<:Any, T} where T
782782
N = npartials(DT)
783-
res = similar(x) # result
784-
t = similar(x, T) # temporary Array{T} for fvalue!/fpartial! application
783+
tx = similar(x, T)
784+
ty = similar(y, T) # temporary Array{T} for fvalue!/fpartial! application
785785
# y allows res to be accessed as Array{T}
786-
y = reinterpret(reshape, T, res)
787-
@assert size(y) == (N + 1, length(res))
786+
yarr = reinterpret(reshape, T, y)
787+
@assert size(yarr) == (N + 1, size(y)...)
788788
ystride = size(y, 1)
789789

790790
# calculate res values
791791
@inbounds for (j, v) in enumerate(x)
792-
t[j] = value(v)
792+
tx[j] = value(v)
793793
end
794-
fvalue!(t)
794+
fvalue!(ty, tx)
795795
k = 1
796-
@inbounds for tt in t
797-
y[k] = tt
796+
@inbounds for tt in ty
797+
yarr[k] = tt
798798
k += ystride
799799
end
800800

801801
# calculate each res partial
802802
for i in 1:N
803803
@inbounds for (j, v) in enumerate(x)
804-
t[j] = partials(v, i)
804+
tx[j] = partials(v, i)
805805
end
806-
fpartial!(t, i)
806+
fpartial!(ty, tx, i)
807807
k = i + 1
808-
@inbounds for tt in t
809-
y[k] = tt
808+
@inbounds for tt in ty
809+
yarr[k] = tt
810810
k += ystride
811811
end
812812
end
813813

814-
return res
814+
return y
815815
end
816816

817817
for MT in (StridedMatrix{<:LinearAlgebra.BlasFloat},
818818
LowerTriangular{<:LinearAlgebra.BlasFloat},
819-
UpperTriangular{<:LinearAlgebra.BlasFloat}),
820-
XT in (StridedMatrix{<:Dual}, StridedVector{<:Dual})
821-
@eval Base.:\(m::$MT, x::$XT) =
822-
_map_dual_components(Base.Fix1(ldiv!, m), (x, _) -> ldiv!(m, x), x)
819+
UpperTriangular{<:LinearAlgebra.BlasFloat})
820+
821+
@eval function Base.:\(m::$MT, x::StridedVector{<:Dual})
822+
T = valtype(eltype(x))
823+
ldiv!(m', reinterpret(reshape, T, res))
824+
return res
825+
end
826+
827+
@eval Base.:\(m::$MT, x::StridedMatrix{<:Dual}) =
828+
_map_dual_components!((x, _) -> ldiv!(m, x), (x, _, _) -> ldiv!(m, x), similar(x), x)
829+
830+
@eval function Base.:*(m::$MT, x::StridedVector{<:Dual})
831+
T = valtype(eltype(x))
832+
res = similar(x, (size(m, 1),))
833+
mul!(reinterpret(reshape, T, res), reinterpret(reshape, T, x), m')
834+
return res
835+
end
823836

824-
@eval Base.:*(m::$MT, x::$XT) =
825-
_map_dual_components(Base.Fix1(lmul!, m), (x, _) -> lmul!(m, x), x)
837+
@eval Base.:*(m::$MT, x::StridedMatrix{<:Dual}) =
838+
_map_dual_components!((y, x) -> mul!(y, m, x), (y, x, _) -> mul!(y, m, x),
839+
similar(x, (size(m, 1), size(x, 2))), x)
826840
end
827841

828842
###################

0 commit comments

Comments
 (0)