|
186 | 186 | function mul_result_structure(::SDiagonal, ::LowerTriangular{<:Any, <:StaticMatrix}) |
187 | 187 | return LowerTriangular |
188 | 188 | end |
| 189 | +function mul_result_structure(::UnitUpperTriangular{<:Any, <:StaticMatrix}, ::SDiagonal) |
| 190 | + return UpperTriangular |
| 191 | +end |
| 192 | +function mul_result_structure(::UnitLowerTriangular{<:Any, <:StaticMatrix}, ::SDiagonal) |
| 193 | + return LowerTriangular |
| 194 | +end |
| 195 | +function mul_result_structure(::SDiagonal, ::UnitUpperTriangular{<:Any, <:StaticMatrix}) |
| 196 | + return UpperTriangular |
| 197 | +end |
| 198 | +function mul_result_structure(::SDiagonal, ::UnitLowerTriangular{<:Any, <:StaticMatrix}) |
| 199 | + return LowerTriangular |
| 200 | +end |
189 | 201 | function mul_result_structure(::SDiagonal, ::SDiagonal) |
190 | 202 | return Diagonal |
191 | 203 | end |
|
319 | 331 | if sa[2] != 0 |
320 | 332 | retexpr = gen_by_access(wrapped_a) do access_a |
321 | 333 | exprs = mul_smat_vec_exprs(sa, access_a) |
322 | | - return :(@inbounds similar_type(b, T, Size(sa[1]))(tuple($(exprs...)))) |
| 334 | + return :(@inbounds return similar_type(b, T, Size(sa[1]))(tuple($(exprs...)))) |
323 | 335 | end |
324 | 336 | else |
325 | 337 | exprs = [:(zero(T)) for k = 1:sa[1]] |
|
353 | 365 | end |
354 | 366 |
|
355 | 367 | _unstatic_array(::Type{TSA}) where {S, T, N, TSA<:StaticArray{S,T,N}} = AbstractArray{T,N} |
356 | | -for TWR in [Adjoint, Transpose, Symmetric, Hermitian, LowerTriangular, UpperTriangular] |
| 368 | +for TWR in [Adjoint, Transpose, Symmetric, Hermitian, LowerTriangular, UpperTriangular, UnitUpperTriangular, UnitLowerTriangular, Diagonal] |
357 | 369 | @eval _unstatic_array(::Type{$TWR{T,TSA}}) where {S, T, N, TSA<:StaticArray{S,T,N}} = $TWR{T,<:AbstractArray{T,N}} |
358 | 370 | end |
359 | 371 |
|
|
378 | 390 |
|
379 | 391 | @generated function _mul(Sa::Size{sa}, Sb::Size{sb}, a::StaticMatMulLike{<:Any, <:Any, Ta}, b::StaticMatMulLike{<:Any, <:Any, Tb}) where {sa, sb, Ta, Tb} |
380 | 392 | # Heuristic choice for amount of codegen |
381 | | - if sa[1]*sa[2]*sb[2] <= 8*8*8 || a <: Diagonal || b <: Diagonal |
| 393 | + a_tri_mul = a <: LinearAlgebra.AbstractTriangular ? 2 : 1 |
| 394 | + b_tri_mul = b <: LinearAlgebra.AbstractTriangular ? 2 : 1 |
| 395 | + ab_tri_mul = (a_tri_mul == 2 && b_tri_mul == 2) ? 2 : 1 |
| 396 | + if sa[1]*sa[2]*sb[2] <= 8*8*8*a_tri_mul*b_tri_mul*ab_tri_mul || a <: Diagonal || b <: Diagonal |
382 | 397 | return quote |
383 | 398 | @_inline_meta |
384 | 399 | return mul_unrolled(Sa, Sb, a, b) |
|
0 commit comments