250250function Base. iszero (a:: KroneckerArray )
251251 return iszero (a. a) || iszero (a. b)
252252end
253+ function Base. isreal (a:: KroneckerArray )
254+ return isreal (a. a) && isreal (a. b)
255+ end
253256function Base. inv (a:: KroneckerArray )
254257 return inv (a. a) ⊗ inv (a. b)
255258end
270273function Base.:* (a:: KroneckerArray , b:: Number )
271274 return a. a ⊗ (a. b * b)
272275end
276+ function Base.:/ (a:: KroneckerArray , b:: Number )
277+ return a * inv (b)
278+ end
273279
274280function Base.:- (a:: KroneckerArray )
275281 return (- a. a) ⊗ a. b
@@ -291,26 +297,82 @@ for op in (:+, :-)
291297 end
292298end
293299
300+ using Base. Broadcast: AbstractArrayStyle, BroadcastStyle, Broadcasted
301+ struct KroneckerStyle{N,A,B} <: AbstractArrayStyle{N} end
302+ function KroneckerStyle {N} (a:: BroadcastStyle , b:: BroadcastStyle ) where {N}
303+ return KroneckerStyle {N,a,b} ()
304+ end
305+ function KroneckerStyle (a:: AbstractArrayStyle{N} , b:: AbstractArrayStyle{N} ) where {N}
306+ return KroneckerStyle {N} (a, b)
307+ end
308+ function KroneckerStyle {N,A,B} (v:: Val{M} ) where {N,A,B,M}
309+ return KroneckerStyle {M,typeof(A)(v),typeof(B)(v)} ()
310+ end
311+ function Base. BroadcastStyle (:: Type{<:KroneckerArray{<:Any,N,A,B}} ) where {N,A,B}
312+ return KroneckerStyle {N} (BroadcastStyle (A), BroadcastStyle (B))
313+ end
314+ function Base. BroadcastStyle (style1:: KroneckerStyle{N} , style2:: KroneckerStyle{N} ) where {N}
315+ return KroneckerStyle {N} (
316+ BroadcastStyle (style1. a, style2. a), BroadcastStyle (style1. b, style2. b)
317+ )
318+ end
319+ function Base. similar (bc:: Broadcasted{<:KroneckerStyle{N,A,B}} , elt:: Type ) where {N,A,B}
320+ ax_a = map (ax -> ax. product. a, axes (bc))
321+ ax_b = map (ax -> ax. product. b, axes (bc))
322+ bc_a = Broadcasted (A, nothing , (), ax_a)
323+ bc_b = Broadcasted (B, nothing , (), ax_b)
324+ a = similar (bc_a, elt)
325+ b = similar (bc_b, elt)
326+ return a ⊗ b
327+ end
328+ function Base. copyto! (dest:: AbstractArray , bc:: Broadcasted{<:KroneckerStyle} )
329+ return throw (
330+ ArgumentError (
331+ " Arbitrary broadcasting is not supported for KroneckerArrays since they might not preserve the Kronecker structure." ,
332+ ),
333+ )
334+ end
335+
336+ function Base. map (f, a1:: KroneckerArray , a_rest:: KroneckerArray... )
337+ return throw (
338+ ArgumentError (
339+ " Arbitrary mapping is not supported for KroneckerArrays since they might not preserve the Kronecker structure." ,
340+ ),
341+ )
342+ end
343+ function Base. map! (f, dest:: KroneckerArray , a1:: KroneckerArray , a_rest:: KroneckerArray... )
344+ return throw (
345+ ArgumentError (
346+ " Arbitrary mapping is not supported for KroneckerArrays since they might not preserve the Kronecker structure." ,
347+ ),
348+ )
349+ end
294350function Base. map! (:: typeof (identity), dest:: KroneckerArray , a:: KroneckerArray )
295351 dest. a .= a. a
296352 dest. b .= a. b
297353 return dest
298354end
299- function Base. map! (:: typeof (+ ), dest:: KroneckerArray , a:: KroneckerArray , b:: KroneckerArray )
300- if a. b == b. b
301- map! (+ , dest. a, a. a, b. a)
302- dest. b .= a. b
303- elseif a. a == b. a
304- dest. a .= a. a
305- map! (+ , dest. b, a. b, b. b)
306- else
307- throw (
308- ArgumentError (
309- " KroneckerArray addition is only supported when the first or second arguments match." ,
310- ),
355+ for f in [:+ , :- ]
356+ @eval begin
357+ function Base. map! (
358+ :: typeof ($ f), dest:: KroneckerArray , a:: KroneckerArray , b:: KroneckerArray
311359 )
360+ if a. b == b. b
361+ map! ($ f, dest. a, a. a, b. a)
362+ dest. b .= a. b
363+ elseif a. a == b. a
364+ dest. a .= a. a
365+ map! ($ f, dest. b, a. b, b. b)
366+ else
367+ throw (
368+ ArgumentError (
369+ " KroneckerArray addition is only supported when the first or second arguments match." ,
370+ ),
371+ )
372+ end
373+ return dest
374+ end
312375 end
313- return dest
314376end
315377function Base. map! (
316378 f:: Base.Fix1{typeof(*),<:Number} , dest:: KroneckerArray , a:: KroneckerArray
@@ -326,6 +388,16 @@ function Base.map!(
326388 dest. b .= f. f .(a. b, f. x)
327389 return dest
328390end
391+ function Base. map! (
392+ f:: Base.Fix2{typeof(/),<:Number} , dest:: KroneckerArray , a:: KroneckerArray
393+ )
394+ return map! (Base. Fix2 (* , inv (f. x)), dest, a)
395+ end
396+ function Base. map! (:: typeof (conj), dest:: KroneckerArray , a:: KroneckerArray )
397+ dest. a .= conj .(a. a)
398+ dest. b .= conj .(a. b)
399+ return dest
400+ end
329401
330402using LinearAlgebra:
331403 LinearAlgebra,
@@ -343,9 +415,10 @@ using LinearAlgebra:
343415 svd,
344416 svdvals,
345417 tr
346- diagonal (a:: AbstractArray ) = Diagonal (a)
347- function diagonal (a:: KroneckerArray )
348- return Diagonal (a. a) ⊗ Diagonal (a. b)
418+
419+ using DiagonalArrays: DiagonalArrays, diagonal
420+ function DiagonalArrays. diagonal (a:: KroneckerArray )
421+ return diagonal (a. a) ⊗ diagonal (a. b)
349422end
350423
351424function Base.:* (a:: KroneckerArray , b:: KroneckerArray )
@@ -372,6 +445,23 @@ function LinearAlgebra.norm(a::KroneckerArray, p::Int=2)
372445 return norm (a. a, p) ⊗ norm (a. b, p)
373446end
374447
448+ function Base. real (a:: KroneckerArray )
449+ if iszero (imag (a. a)) || iszero (imag (a. b))
450+ return real (a. a) ⊗ real (a. b)
451+ elseif iszero (real (a. a)) || iszero (real (a. b))
452+ return - imag (a. a) ⊗ imag (a. b)
453+ end
454+ return real (a. a) ⊗ real (a. b) - imag (a. a) ⊗ imag (a. b)
455+ end
456+ function Base. imag (a:: KroneckerArray )
457+ if iszero (imag (a. a)) || iszero (real (a. b))
458+ return real (a. a) ⊗ imag (a. b)
459+ elseif iszero (real (a. a)) || iszero (imag (a. b))
460+ return imag (a. a) ⊗ real (a. b)
461+ end
462+ return real (a. a) ⊗ imag (a. b) + imag (a. a) ⊗ real (a. b)
463+ end
464+
375465using MatrixAlgebraKit: MatrixAlgebraKit, diagview
376466function MatrixAlgebraKit. diagview (a:: KroneckerMatrix )
377467 return diagview (a. a) ⊗ diagview (a. b)
@@ -506,6 +596,19 @@ const EyeKronecker{T,A<:Eye{T},B<:AbstractMatrix{T}} = KroneckerMatrix{T,A,B}
506596const KroneckerEye{T,A<: AbstractMatrix{T} ,B<: Eye{T} } = KroneckerMatrix{T,A,B}
507597const EyeEye{T,A<: Eye{T} ,B<: Eye{T} } = KroneckerMatrix{T,A,B}
508598
599+ using DerivableInterfaces: DerivableInterfaces, zero!
600+ function DerivableInterfaces. zero! (a:: EyeKronecker )
601+ zero! (a. b)
602+ return a
603+ end
604+ function DerivableInterfaces. zero! (a:: KroneckerEye )
605+ zero! (a. a)
606+ return a
607+ end
608+ function DerivableInterfaces. zero! (a:: EyeEye )
609+ return throw (ArgumentError (" Can't zero out `Eye ⊗ Eye`." ))
610+ end
611+
509612function Base.:* (a:: Number , b:: EyeKronecker )
510613 return b. a ⊗ (a * b. b)
511614end
@@ -580,29 +683,44 @@ end
580683function Base. map! (:: typeof (identity), dest:: EyeEye , a:: EyeEye )
581684 return error (" Can't write in-place." )
582685end
583- function Base. map! (f:: typeof (+ ), dest:: EyeKronecker , a:: EyeKronecker , b:: EyeKronecker )
584- if dest. a ≠ a. a ≠ b. a
585- throw (
586- ArgumentError (
587- " KroneckerArray addition is only supported when the first or second arguments match." ,
588- ),
589- )
686+ for f in [:+ , :- ]
687+ @eval begin
688+ function Base. map! (:: typeof ($ f), dest:: EyeKronecker , a:: EyeKronecker , b:: EyeKronecker )
689+ if dest. a ≠ a. a ≠ b. a
690+ throw (
691+ ArgumentError (
692+ " KroneckerArray addition is only supported when the first or second arguments match." ,
693+ ),
694+ )
695+ end
696+ map! ($ f, dest. b, a. b, b. b)
697+ return dest
698+ end
699+ function Base. map! (:: typeof ($ f), dest:: KroneckerEye , a:: KroneckerEye , b:: KroneckerEye )
700+ if dest. b ≠ a. b ≠ b. b
701+ throw (
702+ ArgumentError (
703+ " KroneckerArray addition is only supported when the first or second arguments match." ,
704+ ),
705+ )
706+ end
707+ map! ($ f, dest. a, a. a, b. a)
708+ return dest
709+ end
710+ function Base. map! (:: typeof ($ f), dest:: EyeEye , a:: EyeEye , b:: EyeEye )
711+ return error (" Can't write in-place." )
712+ end
590713 end
591- map! (f, dest. b, a. b, b. b)
714+ end
715+ function Base. map! (f:: typeof (- ), dest:: EyeKronecker , a:: EyeKronecker )
716+ map! (f, dest. b, a. b)
592717 return dest
593718end
594- function Base. map! (f:: typeof (+ ), dest:: KroneckerEye , a:: KroneckerEye , b:: KroneckerEye )
595- if dest. b ≠ a. b ≠ b. b
596- throw (
597- ArgumentError (
598- " KroneckerArray addition is only supported when the first or second arguments match." ,
599- ),
600- )
601- end
602- map! (f, dest. a, a. a, b. a)
719+ function Base. map! (f:: typeof (- ), dest:: KroneckerEye , a:: KroneckerEye )
720+ map! (f, dest. a, a. a)
603721 return dest
604722end
605- function Base. map! (f:: typeof (+ ), dest:: EyeEye , a:: EyeEye , b :: EyeEye )
723+ function Base. map! (f:: typeof (- ), dest:: EyeEye , a:: EyeEye )
606724 return error (" Can't write in-place." )
607725end
608726function Base. map! (f:: Base.Fix1{typeof(*),<:Number} , dest:: EyeKronecker , a:: EyeKronecker )
@@ -812,6 +930,74 @@ const SquareEyeKronecker{T,A<:SquareEye{T},B<:AbstractMatrix{T}} = KroneckerMatr
812930const KroneckerSquareEye{T,A<: AbstractMatrix{T} ,B<: SquareEye{T} } = KroneckerMatrix{T,A,B}
813931const SquareEyeSquareEye{T,A<: SquareEye{T} ,B<: SquareEye{T} } = KroneckerMatrix{T,A,B}
814932
933+ # Special case of similar for `SquareEye ⊗ A` and `A ⊗ SquareEye`.
934+ function Base. similar (
935+ a:: SquareEyeKronecker ,
936+ elt:: Type ,
937+ axs:: Tuple {
938+ CartesianProductUnitRange{<: Integer },Vararg{CartesianProductUnitRange{<: Integer }}
939+ },
940+ )
941+ ax_a = map (ax -> ax. product. a, axs)
942+ ax_b = map (ax -> ax. product. b, axs)
943+ eye_ax_a = (only (unique (ax_a)),)
944+ return Eye {elt} (eye_ax_a) ⊗ similar (a. b, elt, ax_b)
945+ end
946+ function Base. similar (
947+ a:: KroneckerSquareEye ,
948+ elt:: Type ,
949+ axs:: Tuple {
950+ CartesianProductUnitRange{<: Integer },Vararg{CartesianProductUnitRange{<: Integer }}
951+ },
952+ )
953+ ax_a = map (ax -> ax. product. a, axs)
954+ ax_b = map (ax -> ax. product. b, axs)
955+ eye_ax_b = (only (unique (ax_b)),)
956+ return similar (a. a, elt, ax_a) ⊗ Eye {elt} (eye_ax_b)
957+ end
958+ function Base. similar (
959+ a:: SquareEyeSquareEye ,
960+ elt:: Type ,
961+ axs:: Tuple {
962+ CartesianProductUnitRange{<: Integer },Vararg{CartesianProductUnitRange{<: Integer }}
963+ },
964+ )
965+ ax_a = map (ax -> ax. product. a, axs)
966+ ax_b = map (ax -> ax. product. b, axs)
967+ eye_ax_a = (only (unique (ax_a)),)
968+ eye_ax_b = (only (unique (ax_b)),)
969+ return Eye {elt} (eye_ax_a) ⊗ Eye {elt} (eye_ax_b)
970+ end
971+
972+ function Base. similar (
973+ arrayt:: Type{<:SquareEyeKronecker{<:Any,<:Any,A}} ,
974+ axs:: NTuple{2,CartesianProductUnitRange{<:Integer}} ,
975+ ) where {A}
976+ ax_a = map (ax -> ax. product. a, axs)
977+ ax_b = map (ax -> ax. product. b, axs)
978+ eye_ax_a = (only (unique (ax_a)),)
979+ return Eye {eltype(arrayt)} (eye_ax_a) ⊗ similar (A, ax_b)
980+ end
981+ function Base. similar (
982+ arrayt:: Type{<:KroneckerSquareEye{<:Any,A}} ,
983+ axs:: NTuple{2,CartesianProductUnitRange{<:Integer}} ,
984+ ) where {A}
985+ ax_a = map (ax -> ax. product. a, axs)
986+ ax_b = map (ax -> ax. product. b, axs)
987+ eye_ax_b = (only (unique (ax_b)),)
988+ return similar (A, ax_a) ⊗ Eye {eltype(arrayt)} (eye_ax_b)
989+ end
990+ function Base. similar (
991+ arrayt:: Type{<:SquareEyeSquareEye} , axs:: NTuple{2,CartesianProductUnitRange{<:Integer}}
992+ )
993+ elt = eltype (arrayt)
994+ ax_a = map (ax -> ax. product. a, axs)
995+ ax_b = map (ax -> ax. product. b, axs)
996+ eye_ax_a = (only (unique (ax_a)),)
997+ eye_ax_b = (only (unique (ax_b)),)
998+ return Eye {elt} (eye_ax_a) ⊗ Eye {elt} (eye_ax_b)
999+ end
1000+
8151001struct SquareEyeAlgorithm{KWargs<: NamedTuple } <: AbstractAlgorithm
8161002 kwargs:: KWargs
8171003end
@@ -884,8 +1070,6 @@ for f in [:left_null!, :right_null!]
8841070 end
8851071end
8861072for f in [
887- :eig_full! ,
888- :eigh_full! ,
8891073 :qr_compact! ,
8901074 :qr_full! ,
8911075 :left_orth! ,
@@ -900,10 +1084,14 @@ for f in [
9001084 _initialize_output_squareeye (:: typeof ($ f), a:: SquareEye , alg) = (a, a)
9011085 end
9021086end
1087+ _initialize_output_squareeye (:: typeof (eig_full!), a:: SquareEye ) = complex .((a, a))
1088+ _initialize_output_squareeye (:: typeof (eig_full!), a:: SquareEye , alg) = complex .((a, a))
1089+ _initialize_output_squareeye (:: typeof (eigh_full!), a:: SquareEye ) = (real (a), a)
1090+ _initialize_output_squareeye (:: typeof (eigh_full!), a:: SquareEye , alg) = (real (a), a)
9031091for f in [:svd_compact! , :svd_full! ]
9041092 @eval begin
905- _initialize_output_squareeye (:: typeof ($ f), a:: SquareEye ) = (a, a , a)
906- _initialize_output_squareeye (:: typeof ($ f), a:: SquareEye , alg) = (a, a , a)
1093+ _initialize_output_squareeye (:: typeof ($ f), a:: SquareEye ) = (a, real (a) , a)
1094+ _initialize_output_squareeye (:: typeof ($ f), a:: SquareEye , alg) = (a, real (a) , a)
9071095 end
9081096end
9091097
@@ -987,10 +1175,12 @@ function MatrixAlgebraKit.right_null!(
9871175 return throw (MethodError (right_null!, (a, F)))
9881176end
9891177
990- for f in [:eig_vals! , :eigh_vals! , :svd_vals! ]
1178+ _initialize_output_squareeye (:: typeof (eig_vals!), a:: SquareEye ) = parent (a)
1179+ _initialize_output_squareeye (:: typeof (eig_vals!), a:: SquareEye , alg) = parent (a)
1180+ for f in [:eigh_vals! , svd_vals!]
9911181 @eval begin
992- _initialize_output_squareeye (:: typeof ($ f), a:: SquareEye ) = parent (a)
993- _initialize_output_squareeye (:: typeof ($ f), a:: SquareEye , alg) = parent (a)
1182+ _initialize_output_squareeye (:: typeof ($ f), a:: SquareEye ) = real ( parent (a) )
1183+ _initialize_output_squareeye (:: typeof ($ f), a:: SquareEye , alg) = real ( parent (a) )
9941184 end
9951185end
9961186
0 commit comments