Skip to content

Commit 9f45d82

Browse files
authored
Better indexing into a KronExpansion (#207)
* Better indexing into a KronExpansion * Update bases.jl * increase coverage * v0.19.6
1 parent a1e2a62 commit 9f45d82

File tree

6 files changed

+58
-11
lines changed

6 files changed

+58
-11
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ContinuumArrays"
22
uuid = "7ae1f121-cc2c-504b-ac30-9b923412ae5c"
3-
version = "0.19.5"
3+
version = "0.19.6"
44

55
[deps]
66
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"

src/bases/bases.jl

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -581,15 +581,34 @@ end
581581

582582

583583
# we represent as a Mul with a banded matrix
584-
# sublayout(::AbstractBasisLayout, ::Type{<:Tuple{<:Inclusion,<:Integer}}) = SubBasisLayout()
585-
sublayout(::AbstractBasisLayout, ::Type{<:Tuple{<:Inclusion,<:AbstractVector}}) = SubBasisLayout()
586-
sublayout(::AbstractBasisLayout, ::Type{<:Tuple{<:AbstractAffineQuasiVector,<:AbstractVector}}) = MappedBasisLayout()
587-
sublayout(::WeightedBasisLayouts, ::Type{<:Tuple{<:AbstractAffineQuasiVector,<:AbstractVector}}) = MappedWeightedBasisLayout()
588-
sublayout(::WeightedBasisLayout, ::Type{<:Tuple{<:Inclusion,<:AbstractVector}}) = SubWeightedBasisLayout()
589-
sublayout(::MappedWeightedBasisLayout, ::Type{<:Tuple{<:Inclusion,<:AbstractVector}}) = MappedWeightedBasisLayout()
584+
# sublayout(::AbstractBasisLayout, ::Type{Tuple{Inclusion,Integer}}) = SubBasisLayout()
585+
sublayout(::AbstractBasisLayout, ::Type{<:Tuple{Inclusion,AbstractVector}}) = SubBasisLayout()
586+
sublayout(::AbstractBasisLayout, ::Type{<:Tuple{AbstractAffineQuasiVector,AbstractVector}}) = MappedBasisLayout()
587+
sublayout(::WeightedBasisLayouts, ::Type{<:Tuple{AbstractAffineQuasiVector,AbstractVector}}) = MappedWeightedBasisLayout()
588+
sublayout(::WeightedBasisLayout, ::Type{<:Tuple{Inclusion,AbstractVector}}) = SubWeightedBasisLayout()
589+
# sublayout(::MappedWeightedBasisLayout, ::Type{<:Tuple{Inclusion,AbstractVector}}) = MappedWeightedBasisLayout() # not used
590+
sublayout(lay::ExpansionLayout, ::Type{<:Tuple{Inclusion,Integer}}) = lay
591+
sublayout(lay::ExpansionLayout, ::Type{<:Tuple{Inclusion,AbstractVector}}) = lay
592+
593+
594+
sub_basis_layout(_, P, j) = basis(P) # TODO: restrict to ExpansionLayout?
595+
function basis(V::SubQuasiArray{<:Any, N, <:Any, <:Tuple{Inclusion,Any}}) where N
596+
P = parent(V)
597+
_,j = parentindices(V)
598+
sub_basis_layout(MemoryLayout(P), P, j)
599+
end
600+
601+
602+
sub_coefficients_layout(_, P, j) = coefficients(P)[:,j] # TODO: restrict to ExpansionLayout?
603+
function coefficients(V::SubQuasiArray{<:Any, N, <:Any, <:Tuple{Inclusion,Any}}) where N
604+
P = parent(V)
605+
_,j = parentindices(V)
606+
sub_coefficients_layout(MemoryLayout(P), P, j)
607+
end
590608

591609
@inline sub_materialize(::AbstractBasisLayout, V::AbstractQuasiArray) = V
592610
@inline sub_materialize(::AbstractBasisLayout, V::AbstractArray) = V
611+
@inline sub_materialize(::ExpansionLayout, V::AbstractQuasiArray) = basis(V) * coefficients(V)
593612

594613
demap(x) = x
595614
demap(x::BroadcastQuasiArray) = BroadcastQuasiArray(x.f, map(demap, arguments(x))...)

src/bases/basiskron.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,17 @@ is a MemoryLayout corresponding to a quasi-matrix corresponding to the 2D expans
99
"""
1010
struct KronExpansionLayout{LayA, LayB} <: AbstractLazyLayout end
1111
applylayout(::Type{typeof(*)}, ::LayA, ::CoefficientLayouts, ::AdjointBasisLayout{LayB}) where {LayA <: AbstractBasisLayout, LayB <: AbstractBasisLayout} = KronExpansionLayout{LayA,LayB}()
12+
1213
sublayout(::KronExpansionLayout, inds) = sublayout(ApplyLayout{typeof(*)}(), inds)
14+
sublayout(::KronExpansionLayout{LayA, LayB}, inds::Type{<:Tuple{Inclusion,AbstractVector}}) where {LayA,LayB} = ExpansionLayout{LayA}()
15+
16+
17+
sub_basis_layout(::KronExpansionLayout, P, j) = first(arguments(P))
18+
19+
20+
function sub_coefficients_layout(::KronExpansionLayout, P, j)
21+
_,X,Bt = arguments(P)
22+
X * Bt[:,j]
23+
end
24+
1325
sum_layout(::KronExpansionLayout, F, dims...) = sum_layout(ApplyLayout{typeof(*)}(), F, dims...)

test/test_basiskron.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,8 @@ end
1818
@test sum(F; dims=2)[0.1,1] 7.3
1919

2020
@test F[0.1,:][0.2] F[:,0.2][0.1] F[0.1,0.2]
21+
22+
@test F[:,[0.1,0.2]][0.3,:] F[0.3,[0.1,0.2]] F[0.3,:][[0.1,0.2]]
23+
@test F[[0.1,0.2],:][:,0.3] F[[0.1,0.2],0.3] F[:,0.3][[0.1,0.2]]
24+
@test F[[0.1,0.2],[0.3,0.4]] F[[0.1,0.2],:][:,[0.3,0.4]] F[:,[0.3,0.4]][[0.1,0.2],:]
2125
end

test/test_chebyshev.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ Base.:(==)(::FooBasis, ::FooBasis) = true
117117
@test MemoryLayout(w[y] .* T[y,:]) isa MappedWeightedBasisLayout
118118
@test wT[y,:][[0.1,0.2],1:5] == (w[y] .* T[y,:])[[0.1,0.2],1:5] == (w .* T[:,1:5])[y,:][[0.1,0.2],:]
119119
@test MemoryLayout(wT[y,1:3]) isa MappedWeightedBasisLayout
120+
@test MemoryLayout(wT[y,1:3][:,1:2]) isa MappedWeightedBasisLayout
120121
@test wT[y,1:3][[0.1,0.2],1:2] == wT[y[[0.1,0.2]],1:2]
121122

122123
@test T[y,:]'T[y,:] grammatrix(T[y,:]) (T'T)/2

test/test_splines.jl

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -655,9 +655,20 @@ import ContinuumArrays: basis, AdjointBasisLayout, ExpansionLayout, BasisLayout,
655655
@test expand(L,f)[0.1] 1
656656
end
657657

658-
@testset "vec" begin
659-
L = LinearSpline([-1,0,1])
660-
F = L * randn(3,1)
661-
@test vec(F)[0.1] == F[0.1,1]
658+
@testset "matrix coefficients" begin
659+
@testset "vec" begin
660+
L = LinearSpline([-1,0,1])
661+
F = L * randn(3,1)
662+
@test vec(F)[0.1] == F[0.1,1]
663+
end
664+
665+
@testset "sub" begin
666+
L = LinearSpline([-1,0,1])
667+
F = L * randn(3,2)
668+
@test MemoryLayout(F[:,1]) isa ExpansionLayout
669+
@test MemoryLayout(F[:,1:2]) isa ExpansionLayout
670+
@test F[:,1][0.1] F[0.1,1]
671+
@test F[:,1:2][0.1,:] F[0.1,1:2]
672+
end
662673
end
663674
end

0 commit comments

Comments
 (0)