|
230 | 230 | # All structured matrices are square, and therefore they only broadcast out if they are size (1, 1) |
231 | 231 | Broadcast.newindex(D::StructuredMatrix, I::CartesianIndex{2}) = size(D) == (1,1) ? CartesianIndex(1,1) : I |
232 | 232 |
|
| 233 | +# Recursively replace wrapped matrices by their parents to improve broadcasting performance |
| 234 | +# We may do this because the indexing within `copyto!` is restricted to the stored indices |
| 235 | +preprocess_broadcasted(::Type{T}, A) where {T} = _preprocess_broadcasted(T, A) |
| 236 | +function preprocess_broadcasted(::Type{T}, bc::Broadcasted) where {T} |
| 237 | + args = map(x -> preprocess_broadcasted(T, x), bc.args) |
| 238 | + Broadcast.broadcasted(bc.f, args...) |
| 239 | +end |
| 240 | +# fallback case that doesn't unwrap at all |
| 241 | +_preprocess_broadcasted(::Type, x) = x |
| 242 | + |
| 243 | +_preprocess_broadcasted(::Type{Diagonal}, d::Diagonal) = d.diag |
| 244 | +# fallback for types that might opt into Diagonal-like structured broadcasting, e.g. wrappers |
| 245 | +_preprocess_broadcasted(::Type{Diagonal}, d::AbstractMatrix) = diagview(d) |
| 246 | + |
| 247 | +function copy(bc::Broadcasted{StructuredMatrixStyle{Diagonal}}) |
| 248 | + if isstructurepreserving(bc) || fzeropreserving(bc) |
| 249 | + # forward the broadcasting operation to the diagonal |
| 250 | + bc2 = preprocess_broadcasted(Diagonal, bc) |
| 251 | + return Diagonal(copy(bc2)) |
| 252 | + else |
| 253 | + @invoke copy(bc::Broadcasted) |
| 254 | + end |
| 255 | +end |
| 256 | + |
233 | 257 | function copyto!(dest::Diagonal, bc::Broadcasted{<:StructuredMatrixStyle}) |
234 | 258 | isvalidstructbc(dest, bc) || return copyto!(dest, convert(Broadcasted{Nothing}, bc)) |
235 | 259 | axs = axes(dest) |
@@ -291,13 +315,6 @@ function copyto!(dest::Tridiagonal, bc::Broadcasted{<:StructuredMatrixStyle}) |
291 | 315 | return dest |
292 | 316 | end |
293 | 317 |
|
294 | | -# Recursively replace wrapped matrices by their parents to improve broadcasting performance |
295 | | -# We may do this because the indexing within `copyto!` is restricted to the stored indices |
296 | | -preprocess_broadcasted(::Type{T}, A) where {T} = _preprocess_broadcasted(T, A) |
297 | | -function preprocess_broadcasted(::Type{T}, bc::Broadcasted) where {T} |
298 | | - args = map(x -> preprocess_broadcasted(T, x), bc.args) |
299 | | - Broadcast.Broadcasted(bc.f, args, bc.axes) |
300 | | -end |
301 | 318 | _preprocess_broadcasted(::Type{LowerTriangular}, A) = lowertridata(A) |
302 | 319 | _preprocess_broadcasted(::Type{UpperTriangular}, A) = uppertridata(A) |
303 | 320 | _preprocess_broadcasted(::Type{UpperHessenberg}, A) = upperhessenbergdata(A) |
|
0 commit comments