Skip to content

Commit 503ca6b

Browse files
committed
Added pvector_from_split_blocks
1 parent f4b9540 commit 503ca6b

File tree

3 files changed

+23
-5
lines changed

3 files changed

+23
-5
lines changed

src/PartitionedArrays.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ export find_local_indices
137137
export SplitVector
138138
export split_vector
139139
export split_vector_blocks
140+
export pvector_from_split_blocks
140141
include("p_vector.jl")
141142

142143
export SplitMatrix

src/p_vector.jl

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -129,21 +129,29 @@ function split_vector_blocks(own::A,ghost::A) where A
129129
SplitVectorBlocks(own,ghost)
130130
end
131131

132-
struct SplitVector{A,T} <: AbstractVector{T}
132+
struct SplitVector{A,B,T} <: AbstractVector{T}
133133
blocks::SplitVectorBlocks{A}
134-
permutation::UnitRange{Int32}
134+
permutation::B
135135
function SplitVector(
136136
blocks::SplitVectorBlocks{A},permutation) where A
137137
T = eltype(blocks.own)
138-
perm = convert(UnitRange{Int32},permutation)
139-
new{A,T}(blocks,perm)
138+
B = typeof(permutation)
139+
new{A,B,T}(blocks,permutation)
140140
end
141141
end
142142

143-
function split_vector(blocks::SplitVectorBlocks,permutation::UnitRange)
143+
function split_vector(blocks::SplitVectorBlocks,permutation)
144144
SplitVector(blocks,permutation)
145145
end
146146

147+
function split_vector(
148+
own::AbstractVector,
149+
ghost::AbstractVector,
150+
permutation)
151+
blocks = split_vector_blocks(own,ghost)
152+
split_vector(blocks,permutation)
153+
end
154+
147155
Base.IndexStyle(::Type{<:SplitVector}) = IndexLinear()
148156
Base.size(a::SplitVector) = (length(a.blocks.own)+length(a.blocks.ghost),)
149157
function Base.getindex(a::SplitVector,local_id::Int)
@@ -974,6 +982,12 @@ function pvector!(C,V,cache)
974982
end
975983
end
976984

985+
function pvector_from_split_blocks(own,ghost,row_partition)
986+
perms = map(local_permutation,row_partition)
987+
values = map(split_vector,own,ghost,perms)
988+
PVector(values,row_partition)
989+
end
990+
977991
function old_pvector!(I,V,index_partition;kwargs...)
978992
old_pvector!(default_local_values,I,V,index_partition;kwargs...)
979993
end

test/p_vector_tests.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,9 @@ function p_vector_tests(distribute)
5757

5858
@test a == copy(a)
5959

60+
ac = pvector_from_split_blocks(own_values(aa),ghost_values(aa),row_partition)
61+
@test aa == ac
62+
6063
n = 10
6164
I,V = map(rank) do rank
6265
Random.seed!(rank)

0 commit comments

Comments
 (0)