Skip to content

Commit 3858177

Browse files
committed
Testing spmv in petsc
1 parent f7f25ca commit 3858177

File tree

4 files changed

+98
-2
lines changed

4 files changed

+98
-2
lines changed

test/mpi_array/api_test.jl

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
module ApiTests
2+
3+
using Test
4+
using MPI
5+
using PartitionedArrays
6+
7+
repodir = normpath(joinpath(@__DIR__,"..",".."))
8+
9+
defs = joinpath(repodir,"test","mpi_array","api_test_defs.jl")
10+
11+
include(defs)
12+
params = (;nodes_per_dir=(10,10,10),parts_per_dir=(1,1,1))
13+
with_mpi(dist->Defs.main(dist,params))
14+
15+
code = quote
16+
using MPI; MPI.Init()
17+
using PartitionedArrays
18+
include($defs)
19+
params = (;nodes_per_dir=(10,10,10),parts_per_dir=(2,2,2))
20+
with_mpi(dist->Defs.main(dist,params))
21+
end
22+
run(`$(mpiexec()) -np 8 $(Base.julia_cmd()) --project=$repodir -e $code`)
23+
24+
code = quote
25+
using MPI; MPI.Init()
26+
using PartitionedArrays
27+
include($defs)
28+
params = (;nodes_per_dir=(10,10,10),parts_per_dir=(2,4,1))
29+
with_mpi(dist->Defs.main(dist,params))
30+
end
31+
run(`$(mpiexec()) -np 8 $(Base.julia_cmd()) --project=$repodir -e $code`)
32+
33+
end # module

test/mpi_array/api_test_defs.jl

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
module Defs
2+
3+
using PartitionedArrays
4+
using PetscCall
5+
using LinearAlgebra
6+
using Test
7+
8+
function spmv_petsc!(b,A,x)
9+
# Convert the input to petsc objects
10+
mat = Ref{PetscCall.Mat}()
11+
vec_b = Ref{PetscCall.Vec}()
12+
vec_x = Ref{PetscCall.Vec}()
13+
parts = linear_indices(partition(x))
14+
petsc_comm = PetscCall.setup_petsc_comm(parts)
15+
args_A = PetscCall.MatCreateMPIAIJWithSplitArrays_args(A,petsc_comm)
16+
args_b = PetscCall.VecCreateMPIWithArray_args(copy(b),petsc_comm)
17+
args_x = PetscCall.VecCreateMPIWithArray_args(copy(x),petsc_comm)
18+
ownership = (args_A,args_b,args_x)
19+
PetscCall.@check_error_code PetscCall.MatCreateMPIAIJWithSplitArrays(args_A...,mat)
20+
PetscCall.@check_error_code PetscCall.MatAssemblyBegin(mat[],PetscCall.MAT_FINAL_ASSEMBLY)
21+
PetscCall.@check_error_code PetscCall.MatAssemblyEnd(mat[],PetscCall.MAT_FINAL_ASSEMBLY)
22+
PetscCall.@check_error_code PetscCall.VecCreateMPIWithArray(args_b...,vec_b)
23+
PetscCall.@check_error_code PetscCall.VecCreateMPIWithArray(args_x...,vec_x)
24+
# This line does the actual product
25+
PetscCall.@check_error_code PetscCall.MatMult(mat[],vec_x[],vec_b[])
26+
# Move the result back to julia
27+
PetscCall.VecCreateMPIWithArray_args_reversed!(b,args_b)
28+
# Cleanup
29+
GC.@preserve ownership PetscCall.@check_error_code PetscCall.MatDestroy(mat)
30+
GC.@preserve ownership PetscCall.@check_error_code PetscCall.VecDestroy(vec_b)
31+
GC.@preserve ownership PetscCall.@check_error_code PetscCall.VecDestroy(vec_x)
32+
b
33+
end
34+
35+
function main(distribute,params)
36+
nodes_per_dir = params.nodes_per_dir
37+
parts_per_dir = params.parts_per_dir
38+
np = prod(parts_per_dir)
39+
ranks = LinearIndices((np,)) |> distribute
40+
A = PartitionedArrays.laplace_matrix(nodes_per_dir,parts_per_dir,ranks)
41+
rows = partition(axes(A,1))
42+
cols = partition(axes(A,2))
43+
x = pones(cols)
44+
b1 = pzeros(rows)
45+
b2 = pzeros(rows)
46+
mul!(b1,A,x)
47+
if ! PetscCall.initialized()
48+
PetscCall.init()
49+
end
50+
spmv_petsc!(b2,A,x)
51+
c = b1-b2
52+
tol = 1.0e-12
53+
@test norm(b1) > tol
54+
@test norm(b2) > tol
55+
@test norm(c)/norm(b1) < tol
56+
end
57+
58+
end #module

test/mpi_array/ksp_test.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
module KspTests
2+
13
using MPI
24
using Test
35

@@ -19,6 +21,5 @@ end
1921

2022
run(`$mpiexec_cmd -np 3 $(Base.julia_cmd()) --project=$repodir -e $code`)
2123

22-
nothing
23-
24+
end # module
2425

test/runtests.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@ module PetscCallTest
33
using PetscCall
44
using Test
55

6+
@testset "API" begin
7+
@testset "PartitionedArrays: MPIArray" begin include("mpi_array/api_test.jl") end
8+
end
9+
610
@testset "KSP" begin
711
@testset "Sequential" begin include("ksp_test.jl") end
812
@testset "PartitionedArrays: DebugArray" begin include("debug_array/ksp_test.jl") end

0 commit comments

Comments
 (0)