@@ -776,53 +776,67 @@ end
776776# Efficient left multiplication/division of #
777777# Dual array by a constant matrix #
778778# -------------------------------------------#
779-
780- # creates the copy of x and applies fvalue!(values) to its values, and fpartial!( partial, ix ) to its partials
781- function _map_dual_components (fvalue!, fpartial!, x:: AbstractArray{DT} ) where DT <: Dual{<:Any, T} where T
779+ # creates the copy of x and applies fvalue!(values(y), values(x)) to its values,
780+ # and fpartial!(partial(y, i), partial(y, i), i ) to its partials
781+ function _map_dual_components! (fvalue!, fpartial!, y :: AbstractArray{DT} , x:: AbstractArray{DT} ) where DT <: Dual{<:Any, T} where T
782782 N = npartials (DT)
783- res = similar (x) # result
784- t = similar (x , T) # temporary Array{T} for fvalue!/fpartial! application
783+ tx = similar (x, T)
784+ ty = similar (y , T) # temporary Array{T} for fvalue!/fpartial! application
785785 # y allows res to be accessed as Array{T}
786- y = reinterpret (reshape, T, res )
787- @assert size (y ) == (N + 1 , length (res) )
786+ yarr = reinterpret (reshape, T, y )
787+ @assert size (yarr ) == (N + 1 , size (y) ... )
788788 ystride = size (y, 1 )
789789
790790 # calculate res values
791791 @inbounds for (j, v) in enumerate (x)
792- t [j] = value (v)
792+ tx [j] = value (v)
793793 end
794- fvalue! (t )
794+ fvalue! (ty, tx )
795795 k = 1
796- @inbounds for tt in t
797- y [k] = tt
796+ @inbounds for tt in ty
797+ yarr [k] = tt
798798 k += ystride
799799 end
800800
801801 # calculate each res partial
802802 for i in 1 : N
803803 @inbounds for (j, v) in enumerate (x)
804- t [j] = partials (v, i)
804+ tx [j] = partials (v, i)
805805 end
806- fpartial! (t , i)
806+ fpartial! (ty, tx , i)
807807 k = i + 1
808- @inbounds for tt in t
809- y [k] = tt
808+ @inbounds for tt in ty
809+ yarr [k] = tt
810810 k += ystride
811811 end
812812 end
813813
814- return res
814+ return y
815815end
816816
817817for MT in (StridedMatrix{<: LinearAlgebra.BlasFloat },
818818 LowerTriangular{<: LinearAlgebra.BlasFloat },
819- UpperTriangular{<: LinearAlgebra.BlasFloat }),
820- XT in (StridedMatrix{<: Dual }, StridedVector{<: Dual })
821- @eval Base.:\ (m:: $MT , x:: $XT ) =
822- _map_dual_components (Base. Fix1 (ldiv!, m), (x, _) -> ldiv! (m, x), x)
819+ UpperTriangular{<: LinearAlgebra.BlasFloat })
820+
821+ @eval function Base.:\ (m:: $MT , x:: StridedVector{<:Dual} )
822+ T = valtype (eltype (x))
823+ ldiv! (m' , reinterpret (reshape, T, res))
824+ return res
825+ end
826+
827+ @eval Base.:\ (m:: $MT , x:: StridedMatrix{<:Dual} ) =
828+ _map_dual_components! ((x, _) -> ldiv! (m, x), (x, _, _) -> ldiv! (m, x), similar (x), x)
829+
830+ @eval function Base.:* (m:: $MT , x:: StridedVector{<:Dual} )
831+ T = valtype (eltype (x))
832+ res = similar (x, (size (m, 1 ),))
833+ mul! (reinterpret (reshape, T, res), reinterpret (reshape, T, x), m' )
834+ return res
835+ end
823836
824- @eval Base.:* (m:: $MT , x:: $XT ) =
825- _map_dual_components (Base. Fix1 (lmul!, m), (x, _) -> lmul! (m, x), x)
837+ @eval Base.:* (m:: $MT , x:: StridedMatrix{<:Dual} ) =
838+ _map_dual_components! ((y, x) -> mul! (y, m, x), (y, x, _) -> mul! (y, m, x),
839+ similar (x, (size (m, 1 ), size (x, 2 ))), x)
826840end
827841
828842# ##################
0 commit comments