Skip to content

Commit faa7307

Browse files
committed
fix *, optimize vector case
1 parent aef0df9 commit faa7307

File tree

1 file changed

+36
-22
lines changed

1 file changed

+36
-22
lines changed

src/dual.jl

Lines changed: 36 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -799,53 +799,67 @@ end
799799
# Efficient left multiplication/division of #
800800
# Dual array by a constant matrix #
801801
#-------------------------------------------#
802-
803-
# creates the copy of x and applies fvalue!(values) to its values, and fpartial!(partial, ix) to its partials
804-
function _map_dual_components(fvalue!, fpartial!, x::AbstractArray{DT}) where DT <: Dual{<:Any, T} where T
802+
# creates the copy of x and applies fvalue!(values(y), values(x)) to its values,
803+
# and fpartial!(partial(y, i), partial(y, i), i) to its partials
804+
function _map_dual_components!(fvalue!, fpartial!, y::AbstractArray{DT}, x::AbstractArray{DT}) where DT <: Dual{<:Any, T} where T
805805
N = npartials(DT)
806-
res = similar(x) # result
807-
t = similar(x, T) # temporary Array{T} for fvalue!/fpartial! application
806+
tx = similar(x, T)
807+
ty = similar(y, T) # temporary Array{T} for fvalue!/fpartial! application
808808
# y allows res to be accessed as Array{T}
809-
y = reinterpret(reshape, T, res)
810-
@assert size(y) == (N + 1, length(res))
809+
yarr = reinterpret(reshape, T, y)
810+
@assert size(yarr) == (N + 1, size(y)...)
811811
ystride = size(y, 1)
812812

813813
# calculate res values
814814
@inbounds for (j, v) in enumerate(x)
815-
t[j] = value(v)
815+
tx[j] = value(v)
816816
end
817-
fvalue!(t)
817+
fvalue!(ty, tx)
818818
k = 1
819-
@inbounds for tt in t
820-
y[k] = tt
819+
@inbounds for tt in ty
820+
yarr[k] = tt
821821
k += ystride
822822
end
823823

824824
# calculate each res partial
825825
for i in 1:N
826826
@inbounds for (j, v) in enumerate(x)
827-
t[j] = partials(v, i)
827+
tx[j] = partials(v, i)
828828
end
829-
fpartial!(t, i)
829+
fpartial!(ty, tx, i)
830830
k = i + 1
831-
@inbounds for tt in t
832-
y[k] = tt
831+
@inbounds for tt in ty
832+
yarr[k] = tt
833833
k += ystride
834834
end
835835
end
836836

837-
return res
837+
return y
838838
end
839839

840840
for MT in (StridedMatrix{<:LinearAlgebra.BlasFloat},
841841
LowerTriangular{<:LinearAlgebra.BlasFloat},
842-
UpperTriangular{<:LinearAlgebra.BlasFloat}),
843-
XT in (StridedMatrix{<:Dual}, StridedVector{<:Dual})
844-
@eval Base.:\(m::$MT, x::$XT) =
845-
_map_dual_components(Base.Fix1(ldiv!, m), (x, _) -> ldiv!(m, x), x)
842+
UpperTriangular{<:LinearAlgebra.BlasFloat})
843+
844+
@eval function Base.:\(m::$MT, x::StridedVector{<:Dual})
845+
T = valtype(eltype(x))
846+
ldiv!(m', reinterpret(reshape, T, res))
847+
return res
848+
end
849+
850+
@eval Base.:\(m::$MT, x::StridedMatrix{<:Dual}) =
851+
_map_dual_components!((x, _) -> ldiv!(m, x), (x, _, _) -> ldiv!(m, x), similar(x), x)
852+
853+
@eval function Base.:*(m::$MT, x::StridedVector{<:Dual})
854+
T = valtype(eltype(x))
855+
res = similar(x, (size(m, 1),))
856+
mul!(reinterpret(reshape, T, res), reinterpret(reshape, T, x), m')
857+
return res
858+
end
846859

847-
@eval Base.:*(m::$MT, x::$XT) =
848-
_map_dual_components(Base.Fix1(lmul!, m), (x, _) -> lmul!(m, x), x)
860+
@eval Base.:*(m::$MT, x::StridedMatrix{<:Dual}) =
861+
_map_dual_components!((y, x) -> mul!(y, m, x), (y, x, _) -> mul!(y, m, x),
862+
similar(x, (size(m, 1), size(x, 2))), x)
849863
end
850864

851865
###################

0 commit comments

Comments
 (0)