Skip to content

Commit 4f1f951

Browse files
committed
faster * and / of dual arr by constant mtx
1 parent 61e4dd4 commit 4f1f951

File tree

1 file changed

+54
-0
lines changed

1 file changed

+54
-0
lines changed

src/dual.jl

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -773,6 +773,60 @@ function SpecialFunctions.gamma_inc(a::Real, d::Dual{T,<:Real}, ind::Integer) wh
773773
return (Dual{T}(p, ∂p), Dual{T}(q, -∂p))
774774
end
775775

776+
# Efficient left multiplication/division of #
777+
# Dual array by a constant matrix #
778+
#-------------------------------------------#
779+
780+
# creates the copy of x and applies fvalue!(values) to its values, and fpartial!(partial, ix) to its partials
781+
function _map_dual_components(fvalue!, fpartial!, x::AbstractArray{DT}) where DT <: Dual{<:Any, T} where T
782+
N = npartials(DT)
783+
@assert Base.elsize(x) == sizeof(T) * (N + 1) # check that accessing x as Vector{T} is possible
784+
res = similar(x) # result
785+
t = similar(x, T) # temporary Array{T} for fvalue!/fpartial! application
786+
ystride = N + 1
787+
# y allows res to be accessed as Vector{T}
788+
y = Base.unsafe_wrap(Array, Base.unsafe_convert(Ptr{T}, res),
789+
length(res) * ystride, own=false)
790+
791+
# calculate res values
792+
@inbounds for (j, v) in enumerate(x)
793+
t[j] = value(v)
794+
end
795+
fvalue!(t)
796+
k = 1
797+
@inbounds for tt in t
798+
y[k] = tt
799+
k += ystride
800+
end
801+
802+
# calculate each res partial
803+
for i in 1:N
804+
@inbounds for (j, v) in enumerate(x)
805+
t[j] = partials(v, i)
806+
end
807+
fpartial!(t, i)
808+
k = i + 1
809+
@inbounds for tt in t
810+
y[k] = tt
811+
k += ystride
812+
end
813+
end
814+
815+
return res
816+
end
817+
818+
Base.:\(m::Union{LowerTriangular{<:LinearAlgebra.BlasFloat},
819+
UpperTriangular{<:LinearAlgebra.BlasFloat},
820+
Matrix{<:LinearAlgebra.BlasFloat}},
821+
x::AbstractVecOrMat{<:Dual}) =
822+
_map_dual_components(Base.Fix1(ldiv!, m), (x, _) -> ldiv!(m, x), x)
823+
824+
Base.:*(m::Union{LowerTriangular{<:LinearAlgebra.BlasFloat},
825+
UpperTriangular{<:LinearAlgebra.BlasFloat},
826+
Matrix{<:LinearAlgebra.BlasFloat}},
827+
x::AbstractVecOrMat{<:Dual}) =
828+
_map_dual_components(Base.Fix1(lmul!, m), (x, _) -> ldiv!(m, x), x)
829+
776830
###################
777831
# Pretty Printing #
778832
###################

0 commit comments

Comments
 (0)