@@ -799,53 +799,67 @@ end
799799# Efficient left multiplication/division of #
800800# Dual array by a constant matrix #
801801# -------------------------------------------#
802-
803- # creates the copy of x and applies fvalue!(values) to its values, and fpartial!( partial, ix ) to its partials
804- function _map_dual_components (fvalue!, fpartial!, x:: AbstractArray{DT} ) where DT <: Dual{<:Any, T} where T
802+ # creates the copy of x and applies fvalue!(values(y), values(x)) to its values,
803+ # and fpartial!(partial(y, i), partial(y, i), i ) to its partials
804+ function _map_dual_components! (fvalue!, fpartial!, y :: AbstractArray{DT} , x:: AbstractArray{DT} ) where DT <: Dual{<:Any, T} where T
805805 N = npartials (DT)
806- res = similar (x) # result
807- t = similar (x , T) # temporary Array{T} for fvalue!/fpartial! application
806+ tx = similar (x, T)
807+ ty = similar (y , T) # temporary Array{T} for fvalue!/fpartial! application
808808 # y allows res to be accessed as Array{T}
809- y = reinterpret (reshape, T, res )
810- @assert size (y ) == (N + 1 , length (res) )
809+ yarr = reinterpret (reshape, T, y )
810+ @assert size (yarr ) == (N + 1 , size (y) ... )
811811 ystride = size (y, 1 )
812812
813813 # calculate res values
814814 @inbounds for (j, v) in enumerate (x)
815- t [j] = value (v)
815+ tx [j] = value (v)
816816 end
817- fvalue! (t )
817+ fvalue! (ty, tx )
818818 k = 1
819- @inbounds for tt in t
820- y [k] = tt
819+ @inbounds for tt in ty
820+ yarr [k] = tt
821821 k += ystride
822822 end
823823
824824 # calculate each res partial
825825 for i in 1 : N
826826 @inbounds for (j, v) in enumerate (x)
827- t [j] = partials (v, i)
827+ tx [j] = partials (v, i)
828828 end
829- fpartial! (t , i)
829+ fpartial! (ty, tx , i)
830830 k = i + 1
831- @inbounds for tt in t
832- y [k] = tt
831+ @inbounds for tt in ty
832+ yarr [k] = tt
833833 k += ystride
834834 end
835835 end
836836
837- return res
837+ return y
838838end
839839
840840for MT in (StridedMatrix{<: LinearAlgebra.BlasFloat },
841841 LowerTriangular{<: LinearAlgebra.BlasFloat },
842- UpperTriangular{<: LinearAlgebra.BlasFloat }),
843- XT in (StridedMatrix{<: Dual }, StridedVector{<: Dual })
844- @eval Base.:\ (m:: $MT , x:: $XT ) =
845- _map_dual_components (Base. Fix1 (ldiv!, m), (x, _) -> ldiv! (m, x), x)
842+ UpperTriangular{<: LinearAlgebra.BlasFloat })
843+
844+ @eval function Base.:\ (m:: $MT , x:: StridedVector{<:Dual} )
845+ T = valtype (eltype (x))
846+ ldiv! (m' , reinterpret (reshape, T, res))
847+ return res
848+ end
849+
850+ @eval Base.:\ (m:: $MT , x:: StridedMatrix{<:Dual} ) =
851+ _map_dual_components! ((x, _) -> ldiv! (m, x), (x, _, _) -> ldiv! (m, x), similar (x), x)
852+
853+ @eval function Base.:* (m:: $MT , x:: StridedVector{<:Dual} )
854+ T = valtype (eltype (x))
855+ res = similar (x, (size (m, 1 ),))
856+ mul! (reinterpret (reshape, T, res), reinterpret (reshape, T, x), m' )
857+ return res
858+ end
846859
847- @eval Base.:* (m:: $MT , x:: $XT ) =
848- _map_dual_components (Base. Fix1 (lmul!, m), (x, _) -> lmul! (m, x), x)
860+ @eval Base.:* (m:: $MT , x:: StridedMatrix{<:Dual} ) =
861+ _map_dual_components! ((y, x) -> mul! (y, m, x), (y, x, _) -> mul! (y, m, x),
862+ similar (x, (size (m, 1 ), size (x, 2 ))), x)
849863end
850864
851865# ##################
0 commit comments