@@ -4,6 +4,7 @@ using PartitionedArrays
44using PetscCall
55using LinearAlgebra
66using Test
7+ using SparseArrays
78
89function spmv_petsc! (b,A,x)
910 # Convert the input to petsc objects
@@ -57,6 +58,42 @@ function test_spmm_petsc(A,B)
5758 GC. @preserve ownership PetscCall. @check_error_code PetscCall. MatDestroy (mat_C)
5859end
5960
61+ function petsc_coo (petsc_comm,I,J,V,rows,cols)
62+ m = own_length (rows)
63+ n = own_length (cols)
64+ M = PetscCall. PETSC_DECIDE
65+ N = PetscCall. PETSC_DECIDE
66+ I .= I .- 1
67+ J .= J .- 1
68+ ownership = (I,J,V)
69+ ncoo = length (I)
70+ A = Ref {PetscCall.Mat} ()
71+ PetscCall. @check_error_code PetscCall. MatCreate (petsc_comm,A)
72+ PetscCall. @check_error_code PetscCall. MatSetType (A[],PetscCall. MATMPIAIJ)
73+ PetscCall. @check_error_code PetscCall. MatSetSizes (A[],m,n,M,N)
74+ PetscCall. @check_error_code PetscCall. MatSetFromOptions (A[])
75+ PetscCall. @check_error_code PetscCall. MatSetPreallocationCOO (A[],ncoo,I,J)
76+ PetscCall. @check_error_code PetscCall. MatSetValuesCOO (A[],V,PetscCall. ADD_VALUES)
77+ # PetscCall.@check_error_code PetscCall.MatAssemblyBegin(A[],PetscCall.MAT_FINAL_ASSEMBLY)
78+ # PetscCall.@check_error_code PetscCall.MatAssemblyEnd(A[],PetscCall.MAT_FINAL_ASSEMBLY)
79+ GC. @preserve ownership PetscCall. @check_error_code PetscCall. MatDestroy (A)
80+ end
81+
82+ function generate_coo (args... )
83+ A = PartitionedArrays. laplace_matrix (args... )
84+ row_partition = partition (axes (A,1 ))
85+ col_partition = partition (axes (A,2 ))
86+ (I,J,V) = map (partition (A),row_partition,col_partition) do myA,rows,cols
87+ Id,Jd,Vd = findnz (myA. blocks. own_own)
88+ Io,Jo,Vo = findnz (myA. blocks. own_ghost)
89+ myI = vcat (map_own_to_global! (Id,rows),map_ghost_to_global! (Io,rows))
90+ myJ = vcat (map_own_to_global! (Jd,cols),map_ghost_to_global! (Jo,cols))
91+ myV = vcat (Vd,Vo)
92+ (myI,myJ,myV)
93+ end |> tuple_of_arrays
94+ I,J,V,row_partition,col_partition
95+ end
96+
6097function main (distribute,params)
6198 nodes_per_dir = params. nodes_per_dir
6299 parts_per_dir = params. parts_per_dir
@@ -80,6 +117,11 @@ function main(distribute,params)
80117 @test norm (c)/ norm (b1) < tol
81118 B = 2 * A
82119 test_spmm_petsc (A,B)
120+ I,J,V,row_partition,col_partition = generate_coo (nodes_per_dir,parts_per_dir,ranks)
121+ petsc_comm = PetscCall. setup_petsc_comm (ranks)
122+ map (I,J,V,row_partition,col_partition) do args...
123+ petsc_coo (petsc_comm,args... )
124+ end
83125end
84126
85127end # module
0 commit comments