@@ -303,9 +303,21 @@ LinearAlgebra.adjoint(A::BlockMap) = AdjointMap(A)
303303# ###########
304304
305305@inline function _blockmul! (y, A:: BlockMap , x, α, β)
306+ if iszero (α)
307+ iszero (β) && return fill! (y, zero (eltype (y)))
308+ isone (β) && return y
309+ return rmul! (y, β)
310+ end
311+ return __blockmul! (MulStyle (A), y, A, x, α, β)
312+ end
313+
314+ @inline __blockmul! (:: FiveArg , y, A, x, α, β) = ___blockmul! (y, A, x, α, β, nothing )
315+ @inline __blockmul! (:: ThreeArg , y, A, x, α, β) = ___blockmul! (y, A, x, α, β, similar (y))
316+
317+ function ___blockmul! (y, A, x, α, β, :: Nothing )
306318 maps, rows, yinds, xinds = A. maps, A. rows, A. rowranges, A. colranges
307319 mapind = 0
308- @views @inbounds for (row, yi) in zip (rows, yinds)
320+ @views for (row, yi) in zip (rows, yinds)
309321 yrow = selectdim (y, 1 , yi)
310322 mapind += 1
311323 mul! (yrow, maps[mapind], selectdim (x, 1 , xinds[mapind]), α, β)
@@ -316,24 +328,50 @@ LinearAlgebra.adjoint(A::BlockMap) = AdjointMap(A)
316328 end
317329 return y
318330end
331+ function ___blockmul! (y, A, x, α, β, z)
332+ maps, rows, yinds, xinds = A. maps, A. rows, A. rowranges, A. colranges
333+ mapind = 0
334+ @views for (row, yi) in zip (rows, yinds)
335+ yrow = selectdim (y, 1 , yi)
336+ zrow = selectdim (z, 1 , yi)
337+ mapind += 1
338+ if MulStyle (maps[mapind]) === ThreeArg () && ! iszero (β)
339+ ! isone (β) && rmul! (yrow, β)
340+ muladd! (ThreeArg (), yrow, maps[mapind], selectdim (x, 1 , xinds[mapind]), α, zrow)
341+ else
342+ mul! (yrow, maps[mapind], selectdim (x, 1 , xinds[mapind]), α, β)
343+ end
344+ for _ in 2 : row
345+ mapind += 1
346+ muladd! (MulStyle (maps[mapind]), yrow, maps[mapind], selectdim (x, 1 , xinds[mapind]), α, zrow)
347+ end
348+ end
349+ return y
350+ end
319351
320352@inline function _transblockmul! (y, A:: BlockMap , x, α, β, transform)
321353 maps, rows, xinds, yinds = A. maps, A. rows, A. rowranges, A. colranges
322- @views @inbounds begin
323- # first block row (rowind = 1) of A, meaning first block column of A', fill all of y
324- xcol = selectdim (x, 1 , first (xinds))
325- for rowind in 1 : first (rows)
326- mul! (selectdim (y, 1 , yinds[rowind]), transform (maps[rowind]), xcol, α, β)
327- end
328- mapind = first (rows)
329- # subsequent block rows of A (block columns of A'),
330- # add results to corresponding parts of y
331- # TODO : think about multithreading
332- for (row, xi) in zip (Base. tail (rows), Base. tail (xinds))
333- xcol = selectdim (x, 1 , xi)
334- for _ in 1 : row
335- mapind += 1
336- mul! (selectdim (y, 1 , yinds[mapind]), transform (maps[mapind]), xcol, α, true )
354+ if iszero (α)
355+ iszero (β) && return fill! (y, zero (eltype (y)))
356+ isone (β) && return y
357+ return rmul! (y, β)
358+ else
359+ @views begin
360+ # first block row (rowind = 1) of A, meaning first block column of A', fill all of y
361+ xcol = selectdim (x, 1 , first (xinds))
362+ for rowind in 1 : first (rows)
363+ mul! (selectdim (y, 1 , yinds[rowind]), transform (maps[rowind]), xcol, α, β)
364+ end
365+ mapind = first (rows)
366+ # subsequent block rows of A (block columns of A'),
367+ # add results to corresponding parts of y
368+ # TODO : think about multithreading
369+ for (row, xi) in zip (Base. tail (rows), Base. tail (xinds))
370+ xcol = selectdim (x, 1 , xi)
371+ for _ in 1 : row
372+ mapind += 1
373+ mul! (selectdim (y, 1 , yinds[mapind]), transform (maps[mapind]), xcol, α, true )
374+ end
337375 end
338376 end
339377 end
@@ -345,34 +383,34 @@ end
345383# ###########
346384
347385Base. @propagate_inbounds A_mul_B! (y:: AbstractVector , A:: BlockMap , x:: AbstractVector ) =
348- mul! (y, A, x)
386+ mul! (y, A, x, true , false )
349387
350388Base. @propagate_inbounds A_mul_B! (y:: AbstractVector , A:: TransposeMap{<:Any,<:BlockMap} , x:: AbstractVector ) =
351- mul! (y, A, x)
389+ mul! (y, A, x, true , false )
352390
353391Base. @propagate_inbounds At_mul_B! (y:: AbstractVector , A:: BlockMap , x:: AbstractVector ) =
354- mul! (y, transpose (A), x)
392+ mul! (y, transpose (A), x, true , false )
355393
356394Base. @propagate_inbounds A_mul_B! (y:: AbstractVector , A:: AdjointMap{<:Any,<:BlockMap} , x:: AbstractVector ) =
357- mul! (y, A, x)
395+ mul! (y, A, x, true , false )
358396
359397Base. @propagate_inbounds Ac_mul_B! (y:: AbstractVector , A:: BlockMap , x:: AbstractVector ) =
360- mul! (y, adjoint (A), x)
398+ mul! (y, adjoint (A), x, true , false )
361399
362400for Atype in (AbstractVector, AbstractMatrix)
363401 @eval Base. @propagate_inbounds function LinearAlgebra. mul! (y:: $Atype , A:: BlockMap , x:: $Atype ,
364- α:: Number = true , β:: Number = false )
402+ α:: Number , β:: Number )
365403 require_one_based_indexing (y, x)
366404 @boundscheck check_dim_mul (y, A, x)
367- return _blockmul! (y, A, x, α, β)
405+ return @inbounds _blockmul! (y, A, x, α, β)
368406 end
369407
370408 for (maptype, transform) in ((:(TransposeMap{<: Any ,<: BlockMap }), :transpose ), (:(AdjointMap{<: Any ,<: BlockMap }), :adjoint ))
371409 @eval Base. @propagate_inbounds function LinearAlgebra. mul! (y:: $Atype , wrapA:: $maptype , x:: $Atype ,
372- α:: Number = true , β:: Number = false )
410+ α:: Number , β:: Number )
373411 require_one_based_indexing (y, x)
374412 @boundscheck check_dim_mul (y, wrapA, x)
375- return _transblockmul! (y, wrapA. lmap, x, α, β, $ transform)
413+ return @inbounds _transblockmul! (y, wrapA. lmap, x, α, β, $ transform)
376414 end
377415 end
378416end
@@ -468,7 +506,7 @@ Base.@propagate_inbounds Ac_mul_B!(y::AbstractVector, A::BlockDiagonalMap, x::Ab
468506
469507for Atype in (AbstractVector, AbstractMatrix)
470508 @eval Base. @propagate_inbounds function LinearAlgebra. mul! (y:: $Atype , A:: BlockDiagonalMap , x:: $Atype ,
471- α:: Number = true , β:: Number = false )
509+ α:: Number , β:: Number )
472510 require_one_based_indexing (y, x)
473511 @boundscheck check_dim_mul (y, A, x)
474512 return _blockscaling! (y, A, x, α, β)
0 commit comments