|
66 | 66 | function gen_by_access(expr_gen, a::Type{<:Adjoint{<:Any, <:StaticVecOrMat}}, asym = :wrapped_a) |
67 | 67 | return expr_gen(:adjoint) |
68 | 68 | end |
| 69 | +function gen_by_access(expr_gen, a::Type{<:SDiagonal}, asym = :wrapped_a) |
| 70 | + return expr_gen(:diagonal) |
| 71 | +end |
69 | 72 | """ |
70 | 73 | gen_by_access(expr_gen, a::Type{<:AbstractArray}, b::Type{<:AbstractArray}) |
71 | 74 |
|
@@ -148,6 +151,13 @@ function gen_by_access(expr_gen, a::Type{<:Adjoint{<:Any, <:StaticMatrix}}, b::T |
148 | 151 | end) |
149 | 152 | end |
150 | 153 | end |
| 154 | +function gen_by_access(expr_gen, a::Type{<:SDiagonal}, b::Type) |
| 155 | + return quote |
| 156 | + return $(gen_by_access(b, :wrapped_b) do access_b |
| 157 | + expr_gen(:diagonal, access_b) |
| 158 | + end) |
| 159 | + end |
| 160 | +end |
151 | 161 |
|
152 | 162 | """ |
153 | 163 | mul_result_structure(a::Type, b::Type) |
|
164 | 174 | function mul_result_structure(::LowerTriangular{<:Any, <:StaticMatrix}, ::LowerTriangular{<:Any, <:StaticMatrix}) |
165 | 175 | return LowerTriangular |
166 | 176 | end |
| 177 | +function mul_result_structure(::UpperTriangular{<:Any, <:StaticMatrix}, ::SDiagonal) |
| 178 | + return UpperTriangular |
| 179 | +end |
| 180 | +function mul_result_structure(::LowerTriangular{<:Any, <:StaticMatrix}, ::SDiagonal) |
| 181 | + return LowerTriangular |
| 182 | +end |
| 183 | +function mul_result_structure(::SDiagonal, ::UpperTriangular{<:Any, <:StaticMatrix}) |
| 184 | + return UpperTriangular |
| 185 | +end |
| 186 | +function mul_result_structure(::SDiagonal, ::LowerTriangular{<:Any, <:StaticMatrix}) |
| 187 | + return LowerTriangular |
| 188 | +end |
| 189 | +function mul_result_structure(::SDiagonal, ::SDiagonal) |
| 190 | + return Diagonal |
| 191 | +end |
167 | 192 |
|
168 | 193 | """ |
169 | 194 | uplo_access(sa, asym, k, j, uplo) |
@@ -247,6 +272,12 @@ function uplo_access(sa, asym, k, j, uplo) |
247 | 272 | return :(transpose($asym[$(LinearIndices(reverse(sa))[j, k])])) |
248 | 273 | elseif uplo == :adjoint |
249 | 274 | return :(adjoint($asym[$(LinearIndices(reverse(sa))[j, k])])) |
| 275 | + elseif uplo == :diagonal |
| 276 | + if k == j |
| 277 | + return :($asym[$k]) |
| 278 | + else |
| 279 | + return :(zero($TAsym)) |
| 280 | + end |
250 | 281 | else |
251 | 282 | error("Unknown uplo: $uplo") |
252 | 283 | end |
@@ -347,12 +378,12 @@ end |
347 | 378 |
|
348 | 379 | @generated function _mul(Sa::Size{sa}, Sb::Size{sb}, a::StaticMatMulLike{<:Any, <:Any, Ta}, b::StaticMatMulLike{<:Any, <:Any, Tb}) where {sa, sb, Ta, Tb} |
349 | 380 | # Heuristic choice for amount of codegen |
350 | | - if sa[1]*sa[2]*sb[2] <= 8*8*8 || !(a <: StaticMatrix) || !(b <: StaticMatrix) |
| 381 | + if sa[1]*sa[2]*sb[2] <= 8*8*8 || a <: Diagonal || b <: Diagonal |
351 | 382 | return quote |
352 | 383 | @_inline_meta |
353 | 384 | return mul_unrolled(Sa, Sb, a, b) |
354 | 385 | end |
355 | | - elseif sa[1] <= 14 && sa[2] <= 14 && sb[2] <= 14 |
| 386 | + elseif (sa[1] <= 14 && sa[2] <= 14 && sb[2] <= 14) || !(a <: StaticMatrix) || !(b <: StaticMatrix) |
356 | 387 | return quote |
357 | 388 | @_inline_meta |
358 | 389 | return mul_unrolled_chunks(Sa, Sb, a, b) |
|
436 | 467 | tmp_type_out = :(SVector{$(sa[1]), T}) |
437 | 468 |
|
438 | 469 | retexpr = gen_by_access(wrapped_a, wrapped_b) do access_a, access_b |
439 | | - vect_exprs = [:($(Symbol("tmp_$k2"))::$tmp_type_out = partly_unrolled_multiply($(Size{sa}()), $(Size{(sb[1],)}()), |
| 470 | + vect_exprs = [:($(Symbol("tmp_$k2")) = partly_unrolled_multiply($(Size{sa}()), $(Size{(sb[1],)}()), |
440 | 471 | a, $(Expr(:call, tmp_type_in, [uplo_access(sb, :b, i, k2, access_b) for i = 1:sb[1]]...)), $(Val(access_a)))::$tmp_type_out) for k2 = 1:sb[2]] |
441 | 472 |
|
442 | 473 | exprs = [:($(Symbol("tmp_$k2"))[$k1]) for k1 = 1:sa[1], k2 = 1:sb[2]] |
|
0 commit comments