@@ -52,6 +52,15 @@ for f in [:rowvals, :nonzeros, :getcolptr]
5252 @eval SparseArrays.$ (f)(A:: ThreadedSparseMatrixCSC ) = SparseArrays.$ (f)(A. A)
5353end
5454
55+
# Sparse * sparse products are not threaded (for now). Each method strips the threaded
# wrapper through its `.A` field, re-applies the lazy adjoint/transpose (if any) to the
# unwrapped matrices, multiplies those, and wraps the product again so that the result is
# still a ThreadedSparseMatrixCSC.
let operand_kinds = (
        (ThreadedSparseMatrixCSC, identity),
        (Adjoint{<:Any,<:ThreadedSparseMatrixCSC}, adjoint),
        (Transpose{<:Any,<:ThreadedSparseMatrixCSC}, transpose),
    )
    # One method per (left kind, right kind) pair. Applying `t` to the argument undoes the
    # lazy wrapper (identity for the plain case), giving access to `.A`; applying `t` again
    # restores the adjoint/transpose on the unwrapped matrix.
    for (T1, t1) in operand_kinds, (T2, t2) in operand_kinds
        @eval function Base.:(*)(A::$T1, B::$T2)
            return ThreadedSparseMatrixCSC($t1($t1(A).A) * $t2($t2(B).A))
        end
    end
end
62+
63+
5564function mul! (C:: StridedVecOrMat , A:: ThreadedSparseMatrixCSC , B:: Union{StridedVector,AdjOrTransDenseMatrix} , α:: Number , β:: Number )
5665 size (A, 2 ) == size (B, 1 ) || throw (DimensionMismatch ())
5766 size (A, 1 ) == size (C, 1 ) || throw (DimensionMismatch ())
@@ -63,9 +72,9 @@ function mul!(C::StridedVecOrMat, A::ThreadedSparseMatrixCSC, B::Union{StridedVe
6372 end
6473 @sync for r in RangeIterator (size (C,2 ), Threads. nthreads ())
6574 Threads. @spawn for k in r
66- @inbounds for col = 1 : size (A, 2 )
75+ @inbounds for col in 1 : size (A, 2 )
6776 αxj = B[col,k] * α
68- for j = getcolptr (A)[col] : ( getcolptr (A)[col + 1 ] - 1 )
77+ for j in nzrange (A, col )
6978 C[rv[j], k] += nzv[j]* αxj
7079 end
7180 end
@@ -74,98 +83,53 @@ function mul!(C::StridedVecOrMat, A::ThreadedSparseMatrixCSC, B::Union{StridedVe
7483 C
7584end
7685
# Threaded products whose left factor is a lazily adjointed / transposed threaded sparse
# matrix. With A = xA.parent and t ∈ (adjoint, transpose), both methods scale C by β and
# then add t(A) * B * α to it in place. The two wrapper types differ only in the
# element-wise operation `t`, so the methods are generated once per wrapper.
for (Wrapper, op) in ((Adjoint, adjoint), (Transpose, transpose))
    # Matrix right-hand side: the columns of C are distributed over the spawned tasks.
    @eval function mul!(
        C::StridedVecOrMat,
        xA::$Wrapper{<:Any,<:ThreadedSparseMatrixCSC},
        B::AdjOrTransDenseMatrix,
        α::Number,
        β::Number,
    )
        A = xA.parent
        if size(A, 2) != size(C, 1)
            throw(DimensionMismatch())
        end
        if size(A, 1) != size(B, 1)
            throw(DimensionMismatch())
        end
        if size(B, 2) != size(C, 2)
            throw(DimensionMismatch())
        end
        vals = nonzeros(A)
        rows = rowvals(A)
        # Pre-scale C by β; β == 0 overwrites C with zeros instead of multiplying.
        if β != 1
            if β != 0
                rmul!(C, β)
            else
                fill!(C, zero(eltype(C)))
            end
        end
        ncols = size(A, 2)
        # RangeIterator is defined elsewhere in this file; presumably it splits
        # 1:size(C, 2) into one chunk per thread — confirm against its definition.
        @sync for chunk in RangeIterator(size(C, 2), Threads.nthreads())
            Threads.@spawn for k in chunk
                @inbounds for col in 1:ncols
                    # Column `col` of A is row `col` of t(A): accumulate its product
                    # with column k of B before touching C.
                    acc = zero(eltype(C))
                    for p in nzrange(A, col)
                        acc += $op(vals[p]) * B[rows[p], k]
                    end
                    C[col, k] += acc * α
                end
            end
        end
        return C
    end

    # Vector right-hand side: the entries of C (one per column of A) are distributed
    # over the spawned tasks.
    @eval function mul!(
        C::StridedVecOrMat,
        xA::$Wrapper{<:Any,<:ThreadedSparseMatrixCSC},
        B::StridedVector,
        α::Number,
        β::Number,
    )
        A = xA.parent
        if size(A, 2) != size(C, 1)
            throw(DimensionMismatch())
        end
        if size(A, 1) != size(B, 1)
            throw(DimensionMismatch())
        end
        if size(B, 2) != size(C, 2)
            throw(DimensionMismatch())
        end
        @assert size(B, 2) == 1
        vals = nonzeros(A)
        rows = rowvals(A)
        # Pre-scale C by β; β == 0 overwrites C with zeros instead of multiplying.
        if β != 1
            if β != 0
                rmul!(C, β)
            else
                fill!(C, zero(eltype(C)))
            end
        end
        @sync for chunk in RangeIterator(size(A, 2), Threads.nthreads())
            Threads.@spawn begin
                @inbounds for col in chunk
                    acc = zero(eltype(C))
                    for p in nzrange(A, col)
                        acc += $op(vals[p]) * B[rows[p]]
                    end
                    # C is indexed linearly here; the size checks above force it to
                    # have a single column.
                    C[col] += acc * α
                end
            end
        end
        return C
    end
end
170134
171135function mul! (C:: StridedVecOrMat , X:: AdjOrTransDenseMatrix , A:: ThreadedSparseMatrixCSC , α:: Number , β:: Number )
@@ -178,18 +142,47 @@ function mul!(C::StridedVecOrMat, X::AdjOrTransDenseMatrix, A::ThreadedSparseMat
178142 if β != 1
179143 β != 0 ? rmul! (C, β) : fill! (C, zero (eltype (C)))
180144 end
145+ # TODO : split in X isa DenseMatrixUnion and X isa Adjoint/Transpose so we can use @simd in the first case (see original code in SparseArrays)
181146 @sync for r in RangeIterator (size (A,2 ), Threads. nthreads ())
182147 Threads. @spawn for col in r
183- @inbounds for k= getcolptr (A)[ col] : ( getcolptr (A)[col + 1 ] - 1 )
184- j = rv [k]
185- αv = nzv [k]* α
186- for multivec_row= 1 : mX
187- C[multivec_row, col] += X[multivec_row, j ] * αv
148+ @inbounds for k in nzrange (A, col)
149+ Aiα = nzv [k] * α
150+ rvk = rv [k]
151+ for multivec_row in 1 : mX
152+ C[multivec_row, col] += X[multivec_row, rvk ] * Aiα
188153 end
189154 end
190155 end
191156 end
192157 C
193158end
194159
# Dense × (adjoint/transpose of a threaded sparse matrix).
# With A = xA.parent and t ∈ (adjoint, transpose), the generated methods scale C by β
# (β == 0 overwrites C with zeros) and then add X * t(A), scaled by α, to C in place.
# Throws DimensionMismatch when the sizes of C, X and t(A) are not compatible.
# Returns C.
for (T, t) in ((Adjoint, adjoint), (Transpose, transpose))
    @eval function mul!(C::StridedVecOrMat, X::AdjOrTransDenseMatrix, xA::$T{<:Any,<:ThreadedSparseMatrixCSC}, α::Number, β::Number)
        A = xA.parent
        mX, nX = size(X)
        nX == size(A, 2) || throw(DimensionMismatch())
        mX == size(C, 1) || throw(DimensionMismatch())
        size(A, 1) == size(C, 2) || throw(DimensionMismatch())
        rv = rowvals(A)
        nzv = nonzeros(A)
        if β != 1
            β != 0 ? rmul!(C, β) : fill!(C, zero(eltype(C)))
        end

        # Transpose of the threaded-sparse * dense algorithm: the rows of C are spread
        # over the spawned tasks. RangeIterator is defined elsewhere in this file;
        # presumably it yields disjoint chunks of 1:size(C, 1) — confirm there.
        @sync for r in RangeIterator(size(C, 1), Threads.nthreads())
            Threads.@spawn for k in r
                @inbounds for col in 1:size(A, 2)
                    αxj = X[k, col] * α
                    for j in nzrange(A, col)
                        # Stored entry (rv[j], col) of A is entry (col, rv[j]) of t(A).
                        # The dense factor stays on the LEFT, as in X * t(A); this
                        # matters for non-commutative element types and matches the
                        # dense * sparse method (`X[...] * Aiα`).
                        C[k, rv[j]] += αxj * $t(nzv[j])
                    end
                end
            end
        end
        return C
    end
end
187+
195188end # module
0 commit comments