@@ -768,53 +768,67 @@ end
768768# Efficient left multiplication/division of #
769769# Dual array by a constant matrix #
770770# -------------------------------------------#
771-
772- # creates the copy of x and applies fvalue!(values) to its values, and fpartial!( partial, ix ) to its partials
773- function _map_dual_components (fvalue!, fpartial!, x:: AbstractArray{DT} ) where DT <: Dual{<:Any, T} where T
771+ # creates the copy of x and applies fvalue!(values(y), values(x)) to its values,
772+ # and fpartial!(partial(y, i), partial(y, i), i ) to its partials
773+ function _map_dual_components! (fvalue!, fpartial!, y :: AbstractArray{DT} , x:: AbstractArray{DT} ) where DT <: Dual{<:Any, T} where T
774774 N = npartials (DT)
775- res = similar (x) # result
776- t = similar (x , T) # temporary Array{T} for fvalue!/fpartial! application
775+ tx = similar (x, T)
776+ ty = similar (y , T) # temporary Array{T} for fvalue!/fpartial! application
777777 # y allows res to be accessed as Array{T}
778- y = reinterpret (reshape, T, res )
779- @assert size (y ) == (N + 1 , length (res) )
778+ yarr = reinterpret (reshape, T, y )
779+ @assert size (yarr ) == (N + 1 , size (y) ... )
780780 ystride = size (y, 1 )
781781
782782 # calculate res values
783783 @inbounds for (j, v) in enumerate (x)
784- t [j] = value (v)
784+ tx [j] = value (v)
785785 end
786- fvalue! (t )
786+ fvalue! (ty, tx )
787787 k = 1
788- @inbounds for tt in t
789- y [k] = tt
788+ @inbounds for tt in ty
789+ yarr [k] = tt
790790 k += ystride
791791 end
792792
793793 # calculate each res partial
794794 for i in 1 : N
795795 @inbounds for (j, v) in enumerate (x)
796- t [j] = partials (v, i)
796+ tx [j] = partials (v, i)
797797 end
798- fpartial! (t , i)
798+ fpartial! (ty, tx , i)
799799 k = i + 1
800- @inbounds for tt in t
801- y [k] = tt
800+ @inbounds for tt in ty
801+ yarr [k] = tt
802802 k += ystride
803803 end
804804 end
805805
806- return res
806+ return y
807807end
808808
809809for MT in (StridedMatrix{<: LinearAlgebra.BlasFloat },
810810 LowerTriangular{<: LinearAlgebra.BlasFloat },
811- UpperTriangular{<: LinearAlgebra.BlasFloat }),
812- XT in (StridedMatrix{<: Dual }, StridedVector{<: Dual })
813- @eval Base.:\ (m:: $MT , x:: $XT ) =
814- _map_dual_components (Base. Fix1 (ldiv!, m), (x, _) -> ldiv! (m, x), x)
811+ UpperTriangular{<: LinearAlgebra.BlasFloat })
812+
813+ @eval function Base.:\ (m:: $MT , x:: StridedVector{<:Dual} )
814+ T = valtype (eltype (x))
815+ ldiv! (m' , reinterpret (reshape, T, res))
816+ return res
817+ end
818+
819+ @eval Base.:\ (m:: $MT , x:: StridedMatrix{<:Dual} ) =
820+ _map_dual_components! ((x, _) -> ldiv! (m, x), (x, _, _) -> ldiv! (m, x), similar (x), x)
821+
822+ @eval function Base.:* (m:: $MT , x:: StridedVector{<:Dual} )
823+ T = valtype (eltype (x))
824+ res = similar (x, (size (m, 1 ),))
825+ mul! (reinterpret (reshape, T, res), reinterpret (reshape, T, x), m' )
826+ return res
827+ end
815828
816- @eval Base.:* (m:: $MT , x:: $XT ) =
817- _map_dual_components (Base. Fix1 (lmul!, m), (x, _) -> lmul! (m, x), x)
829+ @eval Base.:* (m:: $MT , x:: StridedMatrix{<:Dual} ) =
830+ _map_dual_components! ((y, x) -> mul! (y, m, x), (y, x, _) -> mul! (y, m, x),
831+ similar (x, (size (m, 1 ), size (x, 2 ))), x)
818832end
819833
820834# ##################
0 commit comments