diff --git a/src/abstract.jl b/src/abstract.jl index c3940a3..2f33b0f 100644 --- a/src/abstract.jl +++ b/src/abstract.jl @@ -179,6 +179,23 @@ storage_type(op::Adjoint) = storage_type(parent(op)) storage_type(op::Transpose) = storage_type(parent(op)) storage_type(op::Diagonal) = typeof(parent(op)) +@inline function _select_storage_type( + op1::AbstractLinearOperator, + op2::AbstractLinearOperator, + T::Type, +) + S = promote_type(storage_type(op1), storage_type(op2)) + if isconcretetype(S) + return S + elseif isconcretetype(storage_type(op1)) + return storage_type(op1) + elseif isconcretetype(storage_type(op2)) + return storage_type(op2) + else + return Vector{T} + end +end + """ reset!(op) diff --git a/src/cat.jl b/src/cat.jl index 764c789..b20e1be 100644 --- a/src/cat.jl +++ b/src/cat.jl @@ -45,9 +45,7 @@ function hcat(A::AbstractLinearOperator, B::AbstractLinearOperator) ctprod! = @closure (res, w, α, β) -> hcat_ctprod!(res, adjoint(A), adjoint(B), Ancol, Ancol + Bncol, w, α, β) args5 = (has_args5(A) && has_args5(B)) - S = promote_type(storage_type(A), storage_type(B)) - isconcretetype(S) || - throw(LinearOperatorException("storage types cannot be promoted to a concrete type")) + S = _select_storage_type(A, B, T) CompositeLinearOperator(T, nrow, ncol, false, false, prod!, tprod!, ctprod!, args5, S = S) end @@ -104,9 +102,7 @@ function vcat(A::AbstractLinearOperator, B::AbstractLinearOperator) ctprod! = @closure (res, w, α, β) -> vcat_ctprod!(res, adjoint(A), adjoint(B), Anrow, Anrow + Bnrow, w, α, β) args5 = (has_args5(A) && has_args5(B)) - S = promote_type(storage_type(A), storage_type(B)) - isconcretetype(S) || - throw(LinearOperatorException("storage types cannot be promoted to a concrete type")) + S = _select_storage_type(A, B, T) CompositeLinearOperator(T, nrow, ncol, false, false, prod!, tprod!, ctprod!, args5, S = S) end diff --git a/src/operations.jl b/src/operations.jl index 18259d4..b02803d 100644 --- a/src/operations.jl +++ b/src/operations.jl @@ -139,9 +139,7 @@ function *(op1::AbstractLinearOperator, op2::AbstractLinearOperator) if m2 != n1 throw(LinearOperatorException("shape mismatch")) end - S = promote_type(storage_type(op1), storage_type(op2)) - isconcretetype(S) || - throw(LinearOperatorException("storage types cannot be promoted to a concrete type")) + S = _select_storage_type(op1, op2, T) #tmp vector for products vtmp = fill!(S(undef, m2), zero(T)) utmp = fill!(S(undef, n1), zero(T)) @@ -210,9 +208,7 @@ function +(op1::AbstractLinearOperator, op2::AbstractLinearOperator) symm = (issymmetric(op1) && issymmetric(op2)) herm = (ishermitian(op1) && ishermitian(op2)) args5 = (has_args5(op1) && has_args5(op2)) - S = promote_type(storage_type(op1), storage_type(op2)) - isconcretetype(S) || - throw(LinearOperatorException("storage types cannot be promoted to a concrete type")) + S = _select_storage_type(op1, op2, T) return CompositeLinearOperator(T, m1, n1, symm, herm, prod!, tprod!, ctprod!, args5, S = S) end