Skip to content

Commit 1804fbe

Browse files
committed
faster * and / of dual arr by constant mtx
1 parent 78c73af commit 1804fbe

File tree

1 file changed

+54
-0
lines changed

1 file changed

+54
-0
lines changed

src/dual.jl

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -765,6 +765,60 @@ function SpecialFunctions.logabsgamma(d::Dual{T,<:Real}) where {T}
765765
return (Dual{T}(y, SpecialFunctions.digamma(x) * partials(d)), s)
766766
end
767767

768+
# Efficient left multiplication/division of #
769+
# Dual array by a constant matrix #
770+
#-------------------------------------------#
771+
772+
# creates the copy of x and applies fvalue!(values) to its values, and fpartial!(partial, ix) to its partials
773+
function _map_dual_components(fvalue!, fpartial!, x::AbstractArray{DT}) where DT <: Dual{<:Any, T} where T
774+
N = npartials(DT)
775+
@assert Base.elsize(x) == sizeof(T) * (N + 1) # check that accessing x as Vector{T} is possible
776+
res = similar(x) # result
777+
t = similar(x, T) # temporary Array{T} for fvalue!/fpartial! application
778+
ystride = N + 1
779+
# y allows res to be accessed as Vector{T}
780+
y = Base.unsafe_wrap(Array, Base.unsafe_convert(Ptr{T}, res),
781+
length(res) * ystride, own=false)
782+
783+
# calculate res values
784+
@inbounds for (j, v) in enumerate(x)
785+
t[j] = value(v)
786+
end
787+
fvalue!(t)
788+
k = 1
789+
@inbounds for tt in t
790+
y[k] = tt
791+
k += ystride
792+
end
793+
794+
# calculate each res partial
795+
for i in 1:N
796+
@inbounds for (j, v) in enumerate(x)
797+
t[j] = partials(v, i)
798+
end
799+
fpartial!(t, i)
800+
k = i + 1
801+
@inbounds for tt in t
802+
y[k] = tt
803+
k += ystride
804+
end
805+
end
806+
807+
return res
808+
end
809+
810+
Base.:\(m::Union{LowerTriangular{<:LinearAlgebra.BlasFloat},
811+
UpperTriangular{<:LinearAlgebra.BlasFloat},
812+
Matrix{<:LinearAlgebra.BlasFloat}},
813+
x::AbstractVecOrMat{<:Dual}) =
814+
_map_dual_components(Base.Fix1(ldiv!, m), (x, _) -> ldiv!(m, x), x)
815+
816+
Base.:*(m::Union{LowerTriangular{<:LinearAlgebra.BlasFloat},
817+
UpperTriangular{<:LinearAlgebra.BlasFloat},
818+
Matrix{<:LinearAlgebra.BlasFloat}},
819+
x::AbstractVecOrMat{<:Dual}) =
820+
_map_dual_components(Base.Fix1(lmul!, m), (x, _) -> ldiv!(m, x), x)
821+
768822
###################
769823
# Pretty Printing #
770824
###################

0 commit comments

Comments
 (0)