1- struct BlockMap{T,As<: Tuple{Vararg{LinearMap}} ,Rs<: Tuple{Vararg{Int}} } <: LinearMap{T}
1+ struct BlockMap{T,As<: Tuple{Vararg{LinearMap}} ,Rs<: Tuple{Vararg{Int}} ,Rranges <: Tuple{Vararg{UnitRange{Int}}} ,Cranges <: Tuple{Vararg{UnitRange{Int}}} } <: LinearMap{T}
22 maps:: As
33 rows:: Rs
4- rowranges:: Vector{UnitRange{Int}}
5- colranges:: Vector{UnitRange{Int}}
4+ rowranges:: Rranges
5+ colranges:: Cranges
66 function BlockMap {T,R,S} (maps:: R , rows:: S ) where {T, R<: Tuple{Vararg{LinearMap}} , S<: Tuple{Vararg{Int}} }
77 for A in maps
88 promote_type (T, eltype (A)) == T || throw (InexactError ())
99 end
1010 rowranges, colranges = rowcolranges (maps, rows)
11- return new {T,R,S} (maps, rows, rowranges, colranges)
11+ return new {T,R,S,typeof(rowranges),typeof(colranges) } (maps, rows, rowranges, colranges)
1212 end
1313end
1414
@@ -28,28 +28,28 @@ Determines the range of rows for each block row and the range of columns for eac
2828map in `maps`, according to its position in a virtual matrix representation of the
2929block linear map obtained from `hvcat(rows, maps...)`.
3030"""
31- function rowcolranges (maps, rows):: Tuple{Vector{UnitRange{Int}},Vector{UnitRange{Int}}}
32- rowranges = Vector {UnitRange{Int}} (undef, length (rows) )
33- colranges = Vector {UnitRange{Int}} (undef, length (maps) )
31+ function rowcolranges (maps, rows)
32+ rowranges = ( )
33+ colranges = ( )
3434 mapind = 0
3535 rowstart = 1
36- for rowind in 1 : length ( rows)
37- xinds = vcat (1 , map (a -> size (a, 2 ), maps[mapind+ 1 : mapind+ rows[rowind] ])... )
36+ for row in rows
37+ xinds = vcat (1 , map (a -> size (a, 2 ), maps[mapind+ 1 : mapind+ row ])... )
3838 cumsum! (xinds, xinds)
3939 mapind += 1
4040 rowend = rowstart + size (maps[mapind], 1 ) - 1
41- rowranges[rowind] = rowstart: rowend
42- colranges[mapind] = xinds[1 ]: xinds[2 ]- 1
43- for colind in 2 : rows[rowind]
41+ rowranges = (rowranges ... , rowstart: rowend)
42+ colranges = (colranges ... , xinds[1 ]: xinds[2 ]- 1 )
43+ for colind in 2 : row
4444 mapind += 1
45- colranges[mapind] = xinds[colind]: xinds[colind+ 1 ]- 1
45+ colranges = (colranges ... , xinds[colind]: xinds[colind+ 1 ]- 1 )
4646 end
4747 rowstart = rowend + 1
4848 end
49- return rowranges, colranges
49+ return rowranges:: NTuple{length(rows), UnitRange{Int}} , colranges:: NTuple{length(maps), UnitRange{Int}}
5050end
5151
52- Base. size (A:: BlockMap ) = (last (A. rowranges[ end ]) , last (A. colranges[ end ] ))
52+ Base. size (A:: BlockMap ) = (last (last ( A. rowranges)) , last (last ( A. colranges) ))
5353
5454# ###########
5555# concatenation
@@ -299,75 +299,82 @@ LinearAlgebra.transpose(A::BlockMap) = TransposeMap(A)
299299LinearAlgebra. adjoint (A:: BlockMap ) = AdjointMap (A)
300300
301301# ###########
302- # multiplication with vectors
302+ # multiplication helper functions
303303# ###########
304304
305- function A_mul_B! (y:: AbstractVector , A:: BlockMap , x:: AbstractVector )
306- require_one_based_indexing (y, x)
307- m, n = size (A)
308- @boundscheck (m == length (y) && n == length (x)) || throw (DimensionMismatch (" A_mul_B!" ))
305+ @inline function _blockmul! (y, A:: BlockMap , x, α, β)
309306 maps, rows, yinds, xinds = A. maps, A. rows, A. rowranges, A. colranges
310307 mapind = 0
311- @views @inbounds for rowind in 1 : length (rows)
312- yrow = y[yinds[rowind]]
308+ @views @inbounds for (row, yi) in zip (rows, yinds )
309+ yrow = selectdim (y, 1 , yi)
313310 mapind += 1
314- A_mul_B ! (yrow, maps[mapind], x[ xinds[mapind]] )
315- for colind in 2 : rows[rowind]
311+ mul ! (yrow, maps[mapind], selectdim (x, 1 , xinds[mapind]), α, β )
312+ for _ in 2 : row
316313 mapind += 1
317- mul! (yrow, maps[mapind], x[ xinds[mapind]], true , true )
314+ mul! (yrow, maps[mapind], selectdim (x, 1 , xinds[mapind]), α , true )
318315 end
319316 end
320317 return y
321318end
322319
323- function At_mul_B! (y:: AbstractVector , A:: BlockMap , x:: AbstractVector )
324- require_one_based_indexing (y, x)
325- m, n = size (A)
326- @boundscheck (n == length (y) && m == length (x)) || throw (DimensionMismatch (" At_mul_B!" ))
320+ @inline function _transblockmul! (y, A:: BlockMap , x, α, β, transform)
327321 maps, rows, xinds, yinds = A. maps, A. rows, A. rowranges, A. colranges
328- mapind = 0
329- # first block row (rowind = 1) of A, meaning first block column of A', fill all of y
330322 @views @inbounds begin
331- xcol = x[xinds[ 1 ]]
332- for colind in 1 : rows[ 1 ]
333- mapind += 1
334- A_mul_B! (y[ yinds[mapind]], transpose (maps[mapind ]), xcol)
323+ # first block row (rowind = 1) of A, meaning first block column of A', fill all of y
324+ xcol = selectdim (x, 1 , first (xinds))
325+ for rowind in 1 : first (rows)
326+ mul! ( selectdim (y, 1 , yinds[rowind]), transform (maps[rowind ]), xcol, α, β )
335327 end
336- # subsequent block rows of A, add results to corresponding parts of y
337- for rowind in 2 : length (rows)
338- xcol = x[xinds[rowind]]
339- for colind in 1 : rows[rowind]
328+ mapind = first (rows)
329+ # subsequent block rows of A (block columns of A'),
330+ # add results to corresponding parts of y
331+ # TODO : think about multithreading
332+ for (row, xi) in zip (Base. tail (rows), Base. tail (xinds))
333+ xcol = selectdim (x, 1 , xi)
334+ for _ in 1 : row
340335 mapind += 1
341- mul! (y[ yinds[mapind]], transpose (maps[mapind]), xcol, true , true )
336+ mul! (selectdim (y, 1 , yinds[mapind]), transform (maps[mapind]), xcol, α , true )
342337 end
343338 end
344339 end
345340 return y
346341end
347342
348- function Ac_mul_B! (y:: AbstractVector , A:: BlockMap , x:: AbstractVector )
349- require_one_based_indexing (y, x)
350- m, n = size (A)
351- @boundscheck (n == length (y) && m == length (x)) || throw (DimensionMismatch (" At_mul_B!" ))
352- maps, rows, xinds, yinds = A. maps, A. rows, A. rowranges, A. colranges
353- mapind = 0
354- # first block row (rowind = 1) of A, fill all of y
355- @views @inbounds begin
356- xcol = x[xinds[1 ]]
357- for colind in 1 : rows[1 ]
358- mapind += 1
359- A_mul_B! (y[yinds[mapind]], adjoint (maps[mapind]), xcol)
360- end
361- # subsequent block rows of A, add results to corresponding parts of y
362- for rowind in 2 : length (rows)
363- xcol = x[xinds[rowind]]
364- for colind in 1 : rows[rowind]
365- mapind += 1
366- mul! (y[yinds[mapind]], adjoint (maps[mapind]), xcol, true , true )
367- end
343+ # ###########
344+ # multiplication with vectors & matrices
345+ # ###########
346+
347+ Base. @propagate_inbounds A_mul_B! (y:: AbstractVector , A:: BlockMap , x:: AbstractVector ) =
348+ mul! (y, A, x)
349+
350+ Base. @propagate_inbounds A_mul_B! (y:: AbstractVector , A:: TransposeMap{<:Any,<:BlockMap} , x:: AbstractVector ) =
351+ mul! (y, A, x)
352+
353+ Base. @propagate_inbounds At_mul_B! (y:: AbstractVector , A:: BlockMap , x:: AbstractVector ) =
354+ mul! (y, transpose (A), x)
355+
356+ Base. @propagate_inbounds A_mul_B! (y:: AbstractVector , A:: AdjointMap{<:Any,<:BlockMap} , x:: AbstractVector ) =
357+ mul! (y, A, x)
358+
359+ Base. @propagate_inbounds Ac_mul_B! (y:: AbstractVector , A:: BlockMap , x:: AbstractVector ) =
360+ mul! (y, adjoint (A), x)
361+
362+ for Atype in (AbstractVector, AbstractMatrix)
363+ @eval Base. @propagate_inbounds function LinearAlgebra. mul! (y:: $Atype , A:: BlockMap , x:: $Atype ,
364+ α:: Number = true , β:: Number = false )
365+ require_one_based_indexing (y, x)
366+ @boundscheck check_dim_mul (y, A, x)
367+ return _blockmul! (y, A, x, α, β)
368+ end
369+
370+ for (maptype, transform) in ((:(TransposeMap{<: Any ,<: BlockMap }), :transpose ), (:(AdjointMap{<: Any ,<: BlockMap }), :adjoint ))
371+ @eval Base. @propagate_inbounds function LinearAlgebra. mul! (y:: $Atype , wrapA:: $maptype , x:: $Atype ,
372+ α:: Number = true , β:: Number = false )
373+ require_one_based_indexing (y, x)
374+ @boundscheck check_dim_mul (y, wrapA, x)
375+ return _transblockmul! (y, wrapA. lmap, x, α, β, $ transform)
368376 end
369377 end
370- return y
371378end
372379
373380# ###########
388395# show(io, T)
389396# print(io, '}')
390397# end
398+
399+ # ###########
400+ # BlockDiagonalMap
401+ # ###########
402+
403+ struct BlockDiagonalMap{T,As<: Tuple{Vararg{LinearMap}} ,Ranges<: Tuple{Vararg{UnitRange{Int}}} } <: LinearMap{T}
404+ maps:: As
405+ rowranges:: Ranges
406+ colranges:: Ranges
407+ function BlockDiagonalMap {T,As} (maps:: As ) where {T, As<: Tuple{Vararg{LinearMap}} }
408+ for A in maps
409+ promote_type (T, eltype (A)) == T || throw (InexactError ())
410+ end
411+ # row ranges
412+ inds = vcat (1 , size .(maps, 1 )... )
413+ cumsum! (inds, inds)
414+ rowranges = ntuple (i -> inds[i]: inds[i+ 1 ]- 1 , Val (length (maps)))
415+ # column ranges
416+ inds[2 : end ] .= size .(maps, 2 )
417+ cumsum! (inds, inds)
418+ colranges = ntuple (i -> inds[i]: inds[i+ 1 ]- 1 , Val (length (maps)))
419+ return new {T,As,typeof(rowranges)} (maps, rowranges, colranges)
420+ end
421+ end
422+
423+ BlockDiagonalMap {T} (maps:: As ) where {T,As<: Tuple{Vararg{LinearMap}} } =
424+ BlockDiagonalMap {T,As} (maps)
425+ BlockDiagonalMap (maps:: LinearMap... ) =
426+ BlockDiagonalMap {promote_type(map(eltype, maps)...)} (maps)
427+
428+ for k in 1 : 8 # is 8 sufficient?
429+ Is = ntuple (n-> :($ (Symbol (:A ,n)):: AbstractMatrix ), Val (k- 1 ))
430+ # yields (:A1, :A2, :A3, ..., :A(k-1))
431+ L = :($ (Symbol (:A ,k)):: LinearMap )
432+ # yields :Ak
433+ mapargs = ntuple (n -> :(LinearMap ($ (Symbol (:A ,n)))), Val (k- 1 ))
434+ # yields (:LinearMap(A1), :LinearMap(A2), ..., :LinearMap(A(k-1)))
435+
436+ @eval begin
437+ SparseArrays. blockdiag ($ (Is... ), $ L, As:: Union{LinearMap,AbstractMatrix} ...) =
438+ BlockDiagonalMap ($ (mapargs... ), $ (Symbol (:A ,k)), convert_to_lmaps (As... )... )
439+ function Base. cat ($ (Is... ), $ L, As:: Union{LinearMap,AbstractMatrix} ...; dims:: Dims{2} )
440+ if dims == (1 ,2 )
441+ return BlockDiagonalMap ($ (mapargs... ), $ (Symbol (:A ,k)), convert_to_lmaps (As... )... )
442+ else
443+ throw (ArgumentError (" dims keyword in cat of LinearMaps must be (1,2)" ))
444+ end
445+ end
446+ end
447+ end
448+
449+ Base. size (A:: BlockDiagonalMap ) = (last (A. rowranges[end ]), last (A. colranges[end ]))
450+
451+ LinearAlgebra. issymmetric (A:: BlockDiagonalMap ) = all (issymmetric, A. maps)
452+ LinearAlgebra. ishermitian (A:: BlockDiagonalMap{<:Real} ) = all (issymmetric, A. maps)
453+ LinearAlgebra. ishermitian (A:: BlockDiagonalMap ) = all (ishermitian, A. maps)
454+
455+ LinearAlgebra. adjoint (A:: BlockDiagonalMap{T} ) where {T} = BlockDiagonalMap {T} (map (adjoint, A. maps))
456+ LinearAlgebra. transpose (A:: BlockDiagonalMap{T} ) where {T} = BlockDiagonalMap {T} (map (transpose, A. maps))
457+
458+ Base.:(== )(A:: BlockDiagonalMap , B:: BlockDiagonalMap ) = (eltype (A) == eltype (B) && A. maps == B. maps)
459+
460+ Base. @propagate_inbounds A_mul_B! (y:: AbstractVector , A:: BlockDiagonalMap , x:: AbstractVector ) =
461+ mul! (y, A, x, true , false )
462+
463+ Base. @propagate_inbounds At_mul_B! (y:: AbstractVector , A:: BlockDiagonalMap , x:: AbstractVector ) =
464+ mul! (y, transpose (A), x, true , false )
465+
466+ Base. @propagate_inbounds Ac_mul_B! (y:: AbstractVector , A:: BlockDiagonalMap , x:: AbstractVector ) =
467+ mul! (y, adjoint (A), x, true , false )
468+
469+ for Atype in (AbstractVector, AbstractMatrix)
470+ @eval Base. @propagate_inbounds function LinearAlgebra. mul! (y:: $Atype , A:: BlockDiagonalMap , x:: $Atype ,
471+ α:: Number = true , β:: Number = false )
472+ require_one_based_indexing (y, x)
473+ @boundscheck check_dim_mul (y, A, x)
474+ return _blockscaling! (y, A, x, α, β)
475+ end
476+ end
477+
478+ @inline function _blockscaling! (y, A:: BlockDiagonalMap , x, α, β)
479+ maps, yinds, xinds = A. maps, A. rowranges, A. colranges
480+ # TODO : think about multi-threading here
481+ @views @inbounds for i in eachindex (yinds, maps, xinds)
482+ mul! (selectdim (y, 1 , yinds[i]), maps[i], selectdim (x, 1 , xinds[i]), α, β)
483+ end
484+ return y
485+ end
0 commit comments