226226for op in (:+ , :- )
227227 @eval begin
228228 function Base. $op (a:: KroneckerArray , b:: KroneckerArray )
229+ iszero (a) && return $ op (b)
230+ iszero (b) && return a
229231 if a. b == b. b
230232 return $ op (a. a, b. a) ⊗ a. b
231233 elseif a. a == b. a
@@ -241,8 +243,15 @@ for op in (:+, :-)
241243 end
242244end
243245
244- using Base. Broadcast: AbstractArrayStyle, BroadcastStyle, Broadcasted
246+ # Allows for customizations for FillArrays.
247+ _BroadcastStyle (x) = BroadcastStyle (x)
248+
249+ using Base. Broadcast: Broadcast, AbstractArrayStyle, BroadcastStyle, Broadcasted
245250struct KroneckerStyle{N,A,B} <: AbstractArrayStyle{N} end
251+ arg1 (:: Type{<:KroneckerStyle{<:Any,A}} ) where {A} = A
252+ arg1 (style:: KroneckerStyle ) = arg1 (typeof (style))
253+ arg2 (:: Type{<:KroneckerStyle{<:Any,B}} ) where {B} = B
254+ arg2 (style:: KroneckerStyle ) = arg2 (typeof (style))
246255function KroneckerStyle {N} (a:: BroadcastStyle , b:: BroadcastStyle ) where {N}
247256 return KroneckerStyle {N,a,b} ()
248257end
@@ -253,30 +262,69 @@ function KroneckerStyle{N,A,B}(v::Val{M}) where {N,A,B,M}
253262 return KroneckerStyle {M,typeof(A)(v),typeof(B)(v)} ()
254263end
255264function Base. BroadcastStyle (:: Type{<:KroneckerArray{<:Any,N,A,B}} ) where {N,A,B}
256- return KroneckerStyle {N} (BroadcastStyle (A), BroadcastStyle (B))
265+ return KroneckerStyle {N} (_BroadcastStyle (A), _BroadcastStyle (B))
257266end
258267function Base. BroadcastStyle (style1:: KroneckerStyle{N} , style2:: KroneckerStyle{N} ) where {N}
259- return KroneckerStyle {N} (
260- BroadcastStyle (style1. a, style2. a), BroadcastStyle (style1. b, style2. b)
261- )
268+ style_a = BroadcastStyle (arg1 (style1), arg1 (style2))
269+ (style_a isa Broadcast. Unknown) && return Broadcast. Unknown ()
270+ style_b = BroadcastStyle (arg2 (style1), arg2 (style2))
271+ (style_b isa Broadcast. Unknown) && return Broadcast. Unknown ()
272+ return KroneckerStyle {N} (style_a, style_b)
262273end
263274function Base. similar (bc:: Broadcasted{<:KroneckerStyle{N,A,B}} , elt:: Type ) where {N,A,B}
264- ax_a = map (ax -> ax . product . a, axes (bc))
265- ax_b = map (ax -> ax . product . b, axes (bc))
275+ ax_a = arg1 .( axes (bc))
276+ ax_b = arg2 .( axes (bc))
266277 bc_a = Broadcasted (A, nothing , (), ax_a)
267278 bc_b = Broadcasted (B, nothing , (), ax_b)
268279 a = similar (bc_a, elt)
269280 b = similar (bc_b, elt)
270281 return a ⊗ b
271282end
283+ # Fallback definition of broadcasting falls back to `map` but assumes
284+ # inputs have been canonicalized to a map-compatible expression already,
285+ # for example by absorbing scalar arguments into the function.
272286function Base. copyto! (dest:: AbstractArray , bc:: Broadcasted{<:KroneckerStyle} )
273- return throw (
274- ArgumentError (
275- " Arbitrary broadcasting is not supported for KroneckerArrays since they might not preserve the Kronecker structure." ,
276- ),
277- )
287+ allequal (axes, bc. args) || throw (ArgumentError (" Broadcasted axes must be equal." ))
288+ map! (bc. f, dest, bc. args... )
289+ return dest
278290end
279291
292+ # Broadcast rewrite rules. Canonicalize inputs to absorb scalar inputs into the
293+ # function.
294+ function Base. broadcasted (style:: KroneckerStyle , :: typeof (* ), a:: Number , b:: KroneckerArray )
295+ return broadcasted (style, Base. Fix1 (* , a), b)
296+ end
297+ function Base. broadcasted (style:: KroneckerStyle , :: typeof (* ), a:: KroneckerArray , b:: Number )
298+ return broadcasted (style, Base. Fix2 (* , b), a)
299+ end
300+ function Base. broadcasted (style:: KroneckerStyle , :: typeof (/ ), a:: KroneckerArray , b:: Number )
301+ return broadcasted (style, Base. Fix2 (/ , b), a)
302+ end
303+ using MapBroadcast: MapBroadcast, MapFunction
304+ function Base. broadcasted (
305+ style:: KroneckerStyle ,
306+ f:: MapFunction{typeof(*),<:Tuple{<:Number,MapBroadcast.Arg}} ,
307+ a:: KroneckerArray ,
308+ )
309+ return broadcasted (style, Base. Fix1 (* , f. args[1 ]), a)
310+ end
311+ function Base. broadcasted (
312+ style:: KroneckerStyle ,
313+ f:: MapFunction{typeof(*),<:Tuple{MapBroadcast.Arg,<:Number}} ,
314+ a:: KroneckerArray ,
315+ )
316+ return broadcasted (style, Base. Fix2 (* , f. args[2 ]), a)
317+ end
318+ function Base. broadcasted (
319+ style:: KroneckerStyle ,
320+ f:: MapFunction{typeof(/),<:Tuple{MapBroadcast.Arg,<:Number}} ,
321+ a:: KroneckerArray ,
322+ )
323+ return broadcasted (style, Base. Fix2 (/ , f. args[2 ]), a)
324+ end
325+
326+ # TODO : Define by converting to a broadcast expession (with MapBroadcast.jl)
327+ # and then constructing the output with `similar`.
280328function Base. map (f, a1:: KroneckerArray , a_rest:: KroneckerArray... )
281329 return throw (
282330 ArgumentError (
@@ -312,6 +360,8 @@ for f in [:+, :-]
312360 function Base. map! (
313361 :: typeof ($ f), dest:: KroneckerArray , a:: KroneckerArray , b:: KroneckerArray
314362 )
363+ iszero (b) && return map! (identity, dest, a)
364+ iszero (a) && return map! ($ f, dest, b)
315365 if a. b == b. b
316366 map! ($ f, dest. a, a. a, b. a)
317367 map! (identity, dest. b, a. b)
@@ -350,6 +400,15 @@ for op in [:*, :/]
350400 end
351401 end
352402end
403+ for f in [:+ , :- ]
404+ @eval begin
405+ function Base. map! (:: typeof ($ f), dest:: KroneckerArray , src:: KroneckerArray )
406+ map! ($ f, dest. a, src. a)
407+ map! (identity, dest. b, src. b)
408+ return dest
409+ end
410+ end
411+ end
353412
354413using DiagonalArrays: DiagonalArrays, diagonal
355414function DiagonalArrays. diagonal (a:: KroneckerArray )
0 commit comments