Skip to content

Commit 6a87f32

Browse files
committed
Bugfix in spmm
1 parent 635bd41 commit 6a87f32

File tree

3 files changed

+81
-43
lines changed

3 files changed

+81
-43
lines changed

src/PartitionedArrays.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,7 @@ export spmm
158158
export spmm!
159159
export spmtm
160160
export spmtm!
161+
export centralize
161162
include("p_sparse_matrix.jl")
162163

163164
export PTimer

src/p_sparse_matrix.jl

Lines changed: 65 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1571,26 +1571,37 @@ function psparse_consistent_impl(
15711571
own_to_global_row = own_to_global(rows_co)
15721572
own_to_global_col = own_to_global(cols_fa)
15731573
ghost_to_global_col = ghost_to_global(cols_fa)
1574-
li_to_p = zeros(Int32,size(A,1))
1574+
nl = size(A,1)
1575+
li_to_ps_ptrs = zeros(Int32,nl+1)
15751576
for p in 1:length(lids_snd)
1576-
li_to_p[lids_snd[p]] .= p
1577+
for li in lids_snd[p]
1578+
li_to_ps_ptrs[li+1] += 1
1579+
end
1580+
end
1581+
length_to_ptrs!(li_to_ps_ptrs)
1582+
ndata = li_to_ps_ptrs[end]-1
1583+
li_to_ps_data = zeros(Int32,ndata)
1584+
for p in 1:length(lids_snd)
1585+
for li in lids_snd[p]
1586+
q = li_to_ps_ptrs[li]
1587+
li_to_ps_data[q] = p
1588+
li_to_ps_ptrs[li] = q + 1
1589+
end
15771590
end
1591+
rewind_ptrs!(li_to_ps_ptrs)
1592+
li_to_ps = JaggedArray(li_to_ps_data,li_to_ps_ptrs)
15781593
ptrs = zeros(Int32,length(parts_snd)+1)
15791594
for (i,j,v) in nziterator(A.blocks.own_own)
15801595
li = own_to_local_row[i]
1581-
p = li_to_p[li]
1582-
if p == 0
1583-
continue
1596+
for p in li_to_ps[li]
1597+
ptrs[p+1] += 1
15841598
end
1585-
ptrs[p+1] += 1
15861599
end
15871600
for (i,j,v) in nziterator(A.blocks.own_ghost)
15881601
li = own_to_local_row[i]
1589-
p = li_to_p[li]
1590-
if p == 0
1591-
continue
1602+
for p in li_to_ps[li]
1603+
ptrs[p+1] += 1
15921604
end
1593-
ptrs[p+1] += 1
15941605
end
15951606
length_to_ptrs!(ptrs)
15961607
ndata = ptrs[end]-1
@@ -1601,30 +1612,26 @@ function psparse_consistent_impl(
16011612
k_snd = JaggedArray(zeros(Int32,ndata),ptrs)
16021613
for (k,(i,j,v)) in enumerate(nziterator(A.blocks.own_own))
16031614
li = own_to_local_row[i]
1604-
p = li_to_p[li]
1605-
if p == 0
1606-
continue
1615+
for p in li_to_ps[li]
1616+
q = ptrs[p]
1617+
I_snd.data[q] = own_to_global_row[i]
1618+
J_snd.data[q] = own_to_global_col[j]
1619+
V_snd.data[q] = v
1620+
k_snd.data[q] = k
1621+
ptrs[p] += 1
16071622
end
1608-
q = ptrs[p]
1609-
I_snd.data[q] = own_to_global_row[i]
1610-
J_snd.data[q] = own_to_global_col[j]
1611-
V_snd.data[q] = v
1612-
k_snd.data[q] = k
1613-
ptrs[p] += 1
16141623
end
16151624
nnz_own_own = nnz(A.blocks.own_own)
16161625
for (k,(i,j,v)) in enumerate(nziterator(A.blocks.own_ghost))
16171626
li = own_to_local_row[i]
1618-
p = li_to_p[li]
1619-
if p == 0
1620-
continue
1627+
for p in li_to_ps[li]
1628+
q = ptrs[p]
1629+
I_snd.data[q] = own_to_global_row[i]
1630+
J_snd.data[q] = ghost_to_global_col[j]
1631+
V_snd.data[q] = v
1632+
k_snd.data[q] = k+nnz_own_own
1633+
ptrs[p] += 1
16211634
end
1622-
q = ptrs[p]
1623-
I_snd.data[q] = own_to_global_row[i]
1624-
J_snd.data[q] = ghost_to_global_col[j]
1625-
V_snd.data[q] = v
1626-
k_snd.data[q] = k+nnz_own_own
1627-
ptrs[p] += 1
16281635
end
16291636
rewind_ptrs!(ptrs)
16301637
cache_snd = (;parts_snd,lids_snd,I_snd,J_snd,V_snd,k_snd)
@@ -1634,12 +1641,12 @@ function psparse_consistent_impl(
16341641
cache_rcv = (;parts_rcv,lids_rcv,I_rcv,J_rcv,V_rcv)
16351642
cache_rcv
16361643
end
1637-
function finalize(A,cache_snd,cache_rcv,rows_co,cols_fa)
1644+
function finalize(A,cache_snd,cache_rcv,rows_co,cols_fa,cols_co)
16381645
I_rcv_data = cache_rcv.I_rcv.data
16391646
J_rcv_data = cache_rcv.J_rcv.data
16401647
V_rcv_data = cache_rcv.V_rcv.data
1641-
global_to_own_col = global_to_own(cols_fa)
1642-
global_to_ghost_col = global_to_ghost(cols_fa)
1648+
global_to_own_col = global_to_own(cols_co)
1649+
global_to_ghost_col = global_to_ghost(cols_co)
16431650
is_own = findall(j->global_to_own_col[j]!=0,J_rcv_data)
16441651
is_ghost = findall(j->global_to_ghost_col[j]!=0,J_rcv_data)
16451652
I_rcv_own = I_rcv_data[is_own]
@@ -1650,20 +1657,24 @@ function psparse_consistent_impl(
16501657
V_rcv_ghost = V_rcv_data[is_ghost]
16511658
map_global_to_ghost!(I_rcv_own,rows_co)
16521659
map_global_to_ghost!(I_rcv_ghost,rows_co)
1653-
map_global_to_own!(J_rcv_own,cols_fa)
1654-
map_global_to_ghost!(J_rcv_ghost,cols_fa)
1655-
own_own = A.blocks.own_own
1656-
own_ghost = A.blocks.own_ghost
1660+
map_global_to_own!(J_rcv_own,cols_co)
1661+
map_global_to_ghost!(J_rcv_ghost,cols_co)
1662+
I2,J2,V2 = findnz(A.blocks.own_ghost)
1663+
map_ghost_to_global!(J2,cols_fa)
1664+
map_global_to_ghost!(J2,cols_co)
1665+
n_own_rows = own_length(rows_co)
16571666
n_ghost_rows = ghost_length(rows_co)
1658-
n_own_cols = own_length(cols_fa)
1659-
n_ghost_cols = ghost_length(cols_fa)
1667+
n_own_cols = own_length(cols_co)
1668+
n_ghost_cols = ghost_length(cols_co)
16601669
TA = typeof(A.blocks.ghost_own)
1670+
own_own = A.blocks.own_own
1671+
own_ghost = compresscoo(TA,I2,J2,V2,n_own_rows,n_ghost_cols)
16611672
ghost_own = compresscoo(TA,I_rcv_own,J_rcv_own,V_rcv_own,n_ghost_rows,n_own_cols)
16621673
ghost_ghost = compresscoo(TA,I_rcv_ghost,J_rcv_ghost,V_rcv_ghost,n_ghost_rows,n_ghost_cols)
16631674
K_own = precompute_nzindex(ghost_own,I_rcv_own,J_rcv_own)
16641675
K_ghost = precompute_nzindex(ghost_ghost,I_rcv_ghost,J_rcv_ghost)
16651676
blocks = split_matrix_blocks(own_own,own_ghost,ghost_own,ghost_ghost)
1666-
values = split_matrix(blocks,local_permutation(rows_co),local_permutation(cols_fa))
1677+
values = split_matrix(blocks,local_permutation(rows_co),local_permutation(cols_co))
16671678
k_snd = cache_snd.k_snd
16681679
V_snd = cache_snd.V_snd
16691680
V_rcv = cache_rcv.V_rcv
@@ -1690,9 +1701,12 @@ function psparse_consistent_impl(
16901701
I_rcv = fetch(t_I)
16911702
J_rcv = fetch(t_J)
16921703
V_rcv = fetch(t_V)
1704+
J_rcv_data = map(x->x.data,J_rcv)
1705+
J_rcv_owner = find_owner(cols_fa,J_rcv_data)
1706+
cols_co = map(union_ghost,cols_fa,J_rcv_data,J_rcv_owner)
16931707
cache_rcv = map(setup_rcv,parts_rcv,lids_rcv,I_rcv,J_rcv,V_rcv)
1694-
values,cache = map(finalize,partition(A),cache_snd,cache_rcv,rows_co,cols_fa) |> tuple_of_arrays
1695-
B = PSparseMatrix(values,rows_co,cols_fa,A.assembled)
1708+
values,cache = map(finalize,partition(A),cache_snd,cache_rcv,rows_co,cols_fa,cols_co) |> tuple_of_arrays
1709+
B = PSparseMatrix(values,rows_co,cols_co,A.assembled)
16961710
if val_parameter(reuse) == false
16971711
B
16981712
else
@@ -1932,7 +1946,7 @@ function spmm(A::PSparseMatrix,B::PSparseMatrix;reuse=Val(false))
19321946
C,cacheC = consistent(B,col_partition;reuse=true) |> fetch
19331947
D_partition,cacheD = map((args...)->spmm(args...;reuse=true),partition(A),partition(C)) |> tuple_of_arrays
19341948
assembled = true
1935-
D = PSparseMatrix(D_partition,partition(axes(A,1)),partition(axes(B,2)),assembled)
1949+
D = PSparseMatrix(D_partition,partition(axes(A,1)),partition(axes(C,2)),assembled)
19361950
if val_parameter(reuse)
19371951
cache = (C,cacheC,cacheD)
19381952
return D,cache
@@ -2056,6 +2070,7 @@ end
20562070
repartition(A::PSparseMatrix,new_rows,new_cols;reuse=false)
20572071
"""
20582072
function repartition(A::PSparseMatrix,new_rows,new_cols;reuse=Val(false))
2073+
@assert A.assembled "repartition on a sub-assembled matrix not implemented yet"
20592074
function prepare_triplets(A_own_own,A_own_ghost,A_rows,A_cols)
20602075
I1,J1,V1 = findnz(A_own_own)
20612076
I2,J2,V2 = findnz(A_own_ghost)
@@ -2146,6 +2161,15 @@ function repartition!(B::PSparseMatrix,c::PVector,A::PSparseMatrix,b::PVector,ca
21462161
end
21472162
end
21482163

2164+
function centralize(A::PSparseMatrix)
2165+
m,n = size(A)
2166+
ranks = linear_indices(partition(A))
2167+
rows_trivial = trivial_partition(ranks,m)
2168+
cols_trivial = trivial_partition(ranks,n)
2169+
a_in_main = repartition(A,rows_trivial,cols_trivial) |> fetch
2170+
own_own_values(a_in_main) |> multicast |> getany
2171+
end
2172+
21492173
"""
21502174
psystem(I,J,V,I2,V2,rows,cols;kwargs...)
21512175
"""

test/p_sparse_matrix_tests.jl

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,9 @@ function p_sparse_matrix_tests(distribute)
157157
A = psparse(I,J,V,row_partition,col_partition,split_format=true,assemble=false) |> fetch
158158
A = psparse(I,J,V,row_partition,col_partition,split_format=true,assemble=true) |> fetch
159159
A = psparse(I,J,V,row_partition,col_partition) |> fetch
160-
display(A)
160+
centralize(A) |> display
161+
B = A*A
162+
@test centralize(B) == centralize(A)*centralize(A)
161163
# TODO Assembly in non-split_format format not yet implemented
162164
#A = psparse(I,J,V,row_partition,col_partition,split_format=false,assemble=true) |> fetch
163165

@@ -252,7 +254,6 @@ function p_sparse_matrix_tests(distribute)
252254
r = A*x-y
253255
map(i->fill!(i,100),ghost_values(r))
254256
@test norm(r) < 1.0e-9
255-
display(A)
256257

257258
rows_trivial = trivial_partition(parts,n)
258259
cols_trivial = rows_trivial
@@ -331,18 +332,30 @@ function p_sparse_matrix_tests(distribute)
331332
A = PartitionedArrays.laplace_matrix(nodes_per_dir,parts_per_dir,parts)
332333

333334
B = A*A
335+
A_seq = centralize(A)
336+
@test centralize(B) A_seq*A_seq
337+
334338
B = spmm(A,A)
339+
@test centralize(B) A_seq*A_seq
335340
B,cacheB = spmm(A,A;reuse=true)
336341
spmm!(B,A,A,cacheB)
342+
@test centralize(B) A_seq*A_seq
337343

338344
B = transpose(A)*A
345+
@test centralize(B) transpose(A_seq)*A_seq
346+
339347
B = spmtm(A,A)
340348
B,cacheB = spmtm(A,A;reuse=true)
349+
@test centralize(B) transpose(A_seq)*A_seq
341350
spmtm!(B,A,A,cacheB)
351+
@test centralize(B) transpose(A_seq)*A_seq
342352

343353
C = rap(transpose(A),A,A)
354+
@test centralize(C) transpose(A_seq)*A_seq*A_seq
344355
C,cacheC = rap(transpose(A),A,A;reuse=true)
356+
@test centralize(C) transpose(A_seq)*A_seq*A_seq
345357
rap!(C,transpose(A),A,A,cacheC)
358+
@test centralize(C) transpose(A_seq)*A_seq*A_seq
346359

347360
r = pzeros(partition(axes(A,2)))
348361
x = pones(partition(axes(A,1)))

0 commit comments

Comments
 (0)