Skip to content

Commit 10c2fe9

Browse files
committed
Refactoring of optional arguments in psparse/pvector/psystem
1 parent 9f9f41c commit 10c2fe9

File tree

2 files changed

+153
-62
lines changed

2 files changed

+153
-62
lines changed

src/p_sparse_matrix.jl

Lines changed: 98 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1056,10 +1056,10 @@ the communications needed in its setup.
10561056
"""
10571057
function psparse(f,I,J,V,rows,cols;
10581058
split_format=true,
1059+
subassembled=false,
10591060
assembled=false,
10601061
assemble=true,
1061-
discover_rows=true,
1062-
discover_cols=true,
1062+
indices = :global,
10631063
restore_ids = true,
10641064
assembly_neighbors_options_rows = (;),
10651065
assembly_neighbors_options_cols = (;),
@@ -1073,50 +1073,106 @@ function psparse(f,I,J,V,rows,cols;
10731073
# Even the matrix compression step could be
10741074
# merged with the assembly step
10751075

1076-
map(I,J) do I,J
1077-
@assert I !== J
1078-
end
1076+
# Checks
1077+
disassembled = (!subassembled && ! assembled) ? true : false
10791078

1080-
if assembled || assemble
1081-
@boundscheck @assert all(i->ghost_length(i)==0,rows)
1079+
@assert indices in (:global,:local)
1080+
if count((subassembled,assembled)) == 2
1081+
error("Only one of the folling flags can be set to true: subassembled, assembled")
1082+
end
1083+
if indices === :global
1084+
map(I,J) do I,J
1085+
@assert I !== J
1086+
end
10821087
end
10831088

1084-
if !assembled && discover_rows
1089+
if disassembled
1090+
# TODO If assemble==true, we can (should) optimize the code
1091+
# to do the conversion from disassembled to (fully) assembled split format
1092+
# in a single shot.
1093+
@assert indices === :global
10851094
I_owner = find_owner(rows,I)
1086-
rows_sa = map(union_ghost,rows,I,I_owner)
1087-
assembly_neighbors(rows_sa;assembly_neighbors_options_rows...)
1088-
else
1089-
rows_sa = rows
1090-
end
1091-
if discover_cols
10921095
J_owner = find_owner(cols,J)
1096+
rows_sa = map(union_ghost,rows,I,I_owner)
10931097
cols_sa = map(union_ghost,cols,J,J_owner)
1098+
assembly_neighbors(rows_sa;assembly_neighbors_options_rows...)
10941099
if ! assemble
1095-
assembly_neighbors(rows_sa;assembly_neighbors_options_cols...)
1100+
# We only need this if we want a subassembled output.
1101+
# For assembled output, this call will be deleted when optimizing
1102+
# the code to do the conversions in a single shot.
1103+
assembly_neighbors(cols_sa;assembly_neighbors_options_cols...)
10961104
end
1097-
else
1105+
map(map_global_to_local!,I,rows_sa)
1106+
map(map_global_to_local!,J,cols_sa)
1107+
values_sa = map(f,I,J,V,map(local_length,rows_sa),map(local_length,cols_sa))
1108+
if val_parameter(reuse)
1109+
K = map(precompute_nzindex,values_sa,I,J)
1110+
end
1111+
if restore_ids
1112+
map(map_local_to_global!,I,rows_sa)
1113+
map(map_local_to_global!,J,cols_sa)
1114+
end
1115+
A = PSparseMatrix(values_sa,rows_sa,cols_sa,assembled)
1116+
if split_format
1117+
B,cacheB = PartitionedArrays.split_format(A;reuse=true)
1118+
else
1119+
B,cacheB = A,nothing
1120+
end
1121+
if assemble
1122+
t = PartitionedArrays.assemble(B,rows;reuse=true,assembly_neighbors_options_cols)
1123+
else
1124+
t = @async B,cacheB
1125+
end
1126+
elseif subassembled
1127+
rows_sa = rows
10981128
cols_sa = cols
1099-
end
1100-
map(map_global_to_local!,I,rows_sa)
1101-
map(map_global_to_local!,J,cols_sa)
1102-
values_sa = map(f,I,J,V,map(local_length,rows_sa),map(local_length,cols_sa))
1103-
if val_parameter(reuse)
1104-
K = map(precompute_nzindex,values_sa,I,J)
1105-
end
1106-
if restore_ids
1107-
map(map_local_to_global!,I,rows_sa)
1108-
map(map_local_to_global!,J,cols_sa)
1109-
end
1110-
A = PSparseMatrix(values_sa,rows_sa,cols_sa,assembled)
1111-
if split_format
1112-
B,cacheB = PartitionedArrays.split_format(A;reuse=true)
1113-
else
1114-
B,cacheB = A,nothing
1115-
end
1116-
if assemble
1117-
t = PartitionedArrays.assemble(B,rows;reuse=true,assembly_neighbors_options_cols)
1118-
else
1129+
if indices === :global
1130+
map(map_global_to_local!,I,rows_sa)
1131+
map(map_global_to_local!,J,cols_sa)
1132+
end
1133+
values_sa = map(f,I,J,V,map(local_length,rows_sa),map(local_length,cols_sa))
1134+
if val_parameter(reuse)
1135+
K = map(precompute_nzindex,values_sa,I,J)
1136+
end
1137+
if indices === :global && restore_ids
1138+
map(map_local_to_global!,I,rows_sa)
1139+
map(map_local_to_global!,J,cols_sa)
1140+
end
1141+
A = PSparseMatrix(values_sa,rows_sa,cols_sa,assembled)
1142+
if split_format
1143+
B,cacheB = PartitionedArrays.split_format(A;reuse=true)
1144+
else
1145+
B,cacheB = A,nothing
1146+
end
1147+
if assemble
1148+
t = PartitionedArrays.assemble(B,rows;reuse=true,assembly_neighbors_options_cols)
1149+
else
1150+
t = @async B,cacheB
1151+
end
1152+
elseif assembled
1153+
rows_fa = rows
1154+
cols_fa = cols
1155+
if indices === :global
1156+
map(map_global_to_local!,I,rows_fa)
1157+
map(map_global_to_local!,J,cols_fa)
1158+
end
1159+
values_fa = map(f,I,J,V,map(local_length,rows_fa),map(local_length,cols_fa))
1160+
if val_parameter(reuse)
1161+
K = map(precompute_nzindex,values_fa,I,J)
1162+
end
1163+
if indices === :global && restore_ids
1164+
map(map_local_to_global!,I,rows_fa)
1165+
map(map_local_to_global!,J,cols_fa)
1166+
end
1167+
A = PSparseMatrix(values_fa,rows_fa,cols_fa,assembled)
1168+
if split_format
1169+
B,cacheB = PartitionedArrays.split_format(A;reuse=true)
1170+
else
1171+
B,cacheB = A,nothing
1172+
end
11191173
t = @async B,cacheB
1174+
else
1175+
error("This line should not be reached")
11201176
end
11211177
if val_parameter(reuse) == false
11221178
return @async begin
@@ -1842,14 +1898,14 @@ end
18421898
psystem(I,J,V,I2,V2,rows,cols;kwargs...)
18431899
"""
18441900
function psystem(I,J,V,I2,V2,rows,cols;
1901+
subassembled=false,
18451902
assembled=false,
18461903
assemble=true,
1847-
discover_rows=true,
1848-
discover_cols=true,
1904+
indices = :global,
18491905
restore_ids = true,
1850-
reuse=Val(false),
18511906
assembly_neighbors_options_rows = (;),
1852-
assembly_neighbors_options_cols = (;)
1907+
assembly_neighbors_options_cols = (;),
1908+
reuse=Val(false)
18531909
)
18541910

18551911
# TODO this is just a reference implementation
@@ -1858,19 +1914,18 @@ function psystem(I,J,V,I2,V2,rows,cols;
18581914
# that we want to generate a matrix and a vector
18591915

18601916
t1 = psparse(I,J,V,rows,cols;
1917+
subassembled,
18611918
assembled,
18621919
assemble,
1863-
discover_rows,
1864-
discover_cols,
18651920
restore_ids,
18661921
assembly_neighbors_options_rows,
18671922
assembly_neighbors_options_cols,
18681923
reuse=true)
18691924

18701925
t2 = pvector(I2,V2,rows;
1926+
subassembled,
18711927
assembled,
18721928
assemble,
1873-
discover_rows,
18741929
restore_ids,
18751930
assembly_neighbors_options_rows,
18761931
reuse=true)

src/p_vector.jl

Lines changed: 55 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -573,38 +573,75 @@ instance of [`PVector`](@ref) allowing latency hiding while performing
573573
the communications needed in its setup.
574574
"""
575575
function pvector(f,I,V,rows;
576+
subassembled=false,
576577
assembled=false,
577578
assemble=true,
578-
discover_rows=true,
579579
restore_ids = true,
580+
indices = :global,
580581
reuse=Val(false),
581582
assembly_neighbors_options_rows = (;)
582583
)
583584

584-
if assembled || assemble
585-
@boundscheck @assert all(i->ghost_length(i)==0,rows)
585+
# Checks
586+
disassembled = (!subassembled && ! assembled) ? true : false
587+
@assert indices in (:global,:local)
588+
if count((subassembled,assembled)) == 2
589+
error("Only one of the folling flags can be set to true: subassembled, assembled")
586590
end
587591

588-
if !assembled && discover_rows
592+
if disassembled
593+
@assert indices === :global
589594
I_owner = find_owner(rows,I)
590595
rows_sa = map(union_ghost,rows,I,I_owner)
591596
assembly_neighbors(rows_sa;assembly_neighbors_options_rows...)
592-
else
597+
map(map_global_to_local!,I,rows_sa)
598+
values_sa = map(f,I,V,map(local_length,rows_sa))
599+
if val_parameter(reuse)
600+
K = map(copy,I)
601+
end
602+
if restore_ids
603+
map(map_local_to_global!,I,rows_sa)
604+
end
605+
A = PVector(values_sa,rows_sa)
606+
if assemble
607+
t = PartitionedArrays.assemble(A,rows;reuse=true)
608+
else
609+
t = @async A, nothing
610+
end
611+
elseif subassembled
593612
rows_sa = rows
594-
end
595-
map(map_global_to_local!,I,rows_sa)
596-
values_sa = map(f,I,V,map(local_length,rows_sa))
597-
if val_parameter(reuse)
598-
K = map(copy,I)
599-
end
600-
if restore_ids
601-
map(map_local_to_global!,I,rows_sa)
602-
end
603-
A = PVector(values_sa,rows_sa)
604-
if !assembled && assemble
605-
t = PartitionedArrays.assemble(A,rows;reuse=true)
613+
if indices === :global
614+
map(map_global_to_local!,I,rows_sa)
615+
end
616+
values_sa = map(f,I,V,map(local_length,rows_sa))
617+
if val_parameter(reuse)
618+
K = map(copy,I)
619+
end
620+
if indices === :global && restore_ids
621+
map(map_local_to_global!,I,rows_sa)
622+
end
623+
A = PVector(values_sa,rows_sa)
624+
if assemble
625+
t = PartitionedArrays.assemble(A,rows;reuse=true)
626+
else
627+
t = @async A, nothing
628+
end
629+
elseif assembled
630+
rows_fa = rows
631+
if indices === :global
632+
map(map_global_to_local!,I,rows_fa)
633+
end
634+
values_fa = map(f,I,V,map(local_length,rows_fa))
635+
if val_parameter(reuse)
636+
K = map(copy,I)
637+
end
638+
if indices === :global && restore_ids
639+
map(map_local_to_global!,I,rows_fa)
640+
end
641+
A = PVector(values_fa,rows_fa)
642+
t = @async A, nothing
606643
else
607-
t = @async A,nothing
644+
error("This line should not be reached")
608645
end
609646
if val_parameter(reuse) == false
610647
return @async begin
@@ -641,7 +678,6 @@ function pvector!(B,V,cache)
641678
end
642679
end
643680

644-
645681
function old_pvector!(I,V,index_partition;kwargs...)
646682
old_pvector!(default_local_values,I,V,index_partition;kwargs...)
647683
end

0 commit comments

Comments
 (0)