Skip to content

Commit 78d9b63

Browse files
authored
support for sparse complex arrays (#180)
* support for sparse complex arrays
1 parent 6e859c1 commit 78d9b63

File tree

2 files changed

+63
-6
lines changed

2 files changed

+63
-6
lines changed

src/mxarray.jl

Lines changed: 47 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -287,12 +287,42 @@ function mxsparse(ty::Type{Float64}, m::Integer, n::Integer, nzmax::Integer)
287287
MxArray(pm)
288288
end
289289

290+
function mxsparse(ty::Type{ComplexF64}, m::Integer, n::Integer, nzmax::Integer)
291+
pm = ccall(mx_create_sparse[], Ptr{Cvoid},
292+
(mwSize, mwSize, mwSize, mxComplexity), m, n, nzmax, mxCOMPLEX)
293+
MxArray(pm)
294+
end
295+
290296
function mxsparse(ty::Type{Bool}, m::Integer, n::Integer, nzmax::Integer)
291297
pm = ccall(mx_create_sparse_logical[], Ptr{Cvoid},
292298
(mwSize, mwSize, mwSize), m, n, nzmax)
293299
MxArray(pm)
294300
end
295301

302+
function _copy_sparse_mat(a::SparseMatrixCSC{V,I}, ir_p::Ptr{mwIndex}, jc_p::Ptr{mwIndex}, pr_p::Ptr{Float64}, pi_p::Ptr{Float64}) where {V<:ComplexF64,I}
303+
colptr::Vector{I} = a.colptr
304+
rinds::Vector{I} = a.rowval
305+
vr::Vector{Float64} = real(a.nzval)
306+
vi::Vector{Float64} = imag(a.nzval)
307+
n::Int = a.n
308+
nnz::Int = length(vr)
309+
310+
# Note: ir and jc contain zero-based indices
311+
312+
ir = unsafe_wrap(Array, ir_p, (nnz,))
313+
for i = 1:nnz
314+
ir[i] = rinds[i] - 1
315+
end
316+
317+
jc = unsafe_wrap(Array, jc_p, (n+1,))
318+
for i = 1:n+1
319+
jc[i] = colptr[i] - 1
320+
end
321+
322+
copyto!(unsafe_wrap(Array, pr_p, (nnz,)), vr)
323+
copyto!(unsafe_wrap(Array, pi_p, (nnz,)), vi)
324+
end
325+
296326
function _copy_sparse_mat(a::SparseMatrixCSC{V,I}, ir_p::Ptr{mwIndex}, jc_p::Ptr{mwIndex}, pr_p::Ptr{V}) where {V,I}
297327
colptr::Vector{I} = a.colptr
298328
rinds::Vector{I} = a.rowval
@@ -315,19 +345,24 @@ function _copy_sparse_mat(a::SparseMatrixCSC{V,I}, ir_p::Ptr{mwIndex}, jc_p::Ptr
315345
copyto!(unsafe_wrap(Array, pr_p, (nnz,)), v)
316346
end
317347

318-
function mxarray(a::SparseMatrixCSC{V,I}) where {V<:Union{Float64,Bool},I}
348+
function mxarray(a::SparseMatrixCSC{V,I}) where {V<:Union{Float64,ComplexF64,Bool},I}
319349
m::Int = a.m
320350
n::Int = a.n
321351
nnz = length(a.nzval)
322352
@assert nnz == a.colptr[n+1]-1
323353

324354
mx = mxsparse(V, m, n, nnz)
325-
326355
ir_p = ccall(mx_get_ir[], Ptr{mwIndex}, (Ptr{Cvoid},), mx)
327356
jc_p = ccall(mx_get_jc[], Ptr{mwIndex}, (Ptr{Cvoid},), mx)
328-
pr_p = ccall(mx_get_pr[], Ptr{V}, (Ptr{Cvoid},), mx)
329357

330-
_copy_sparse_mat(a, ir_p, jc_p, pr_p)
358+
if V <: ComplexF64
359+
pr_p = ccall(mx_get_pr[], Ptr{Float64}, (Ptr{Cvoid},), mx)
360+
pi_p = ccall(mx_get_pi[], Ptr{Float64}, (Ptr{Cvoid},), mx)
361+
_copy_sparse_mat(a, ir_p, jc_p, pr_p, pi_p)
362+
else
363+
pr_p = ccall(mx_get_pr[], Ptr{V}, (Ptr{Cvoid},), mx)
364+
_copy_sparse_mat(a, ir_p, jc_p, pr_p)
365+
end
331366
return mx
332367
end
333368

@@ -537,7 +572,6 @@ function _jsparse(ty::Type{T}, mx::MxArray) where T<:MxRealNum
537572
n = ncols(mx)
538573
ir_ptr = ccall(mx_get_ir[], Ptr{mwIndex}, (Ptr{Cvoid},), mx)
539574
jc_ptr = ccall(mx_get_jc[], Ptr{mwIndex}, (Ptr{Cvoid},), mx)
540-
pr_ptr = ccall(mx_get_pr[], Ptr{T}, (Ptr{Cvoid},), mx)
541575

542576
jc_a::Vector{mwIndex} = unsafe_wrap(Array, jc_ptr, (n+1,))
543577
nnz = jc_a[n+1]
@@ -555,8 +589,15 @@ function _jsparse(ty::Type{T}, mx::MxArray) where T<:MxRealNum
555589
jc[i] = jc_x[i] + 1
556590
end
557591

592+
pr_ptr = ccall(mx_get_pr[], Ptr{T}, (Ptr{Cvoid},), mx)
558593
pr::Vector{T} = copy(unsafe_wrap(Array, pr_ptr, (nnz,)))
559-
return SparseMatrixCSC(m, n, jc, ir, pr)
594+
if is_complex(mx)
595+
pi_ptr = ccall(mx_get_pi[], Ptr{T}, (Ptr{Cvoid},), mx)
596+
pi::Vector{T} = copy(unsafe_wrap(Array, pi_ptr, (nnz,)))
597+
return SparseMatrixCSC(m, n, jc, ir, pr + im.*pi)
598+
else
599+
return SparseMatrixCSC(m, n, jc, ir, pr)
600+
end
560601
end
561602

562603
function jsparse(mx::MxArray)

test/mxarray.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,15 @@ a2 = jsparse(a_mx)
209209
@test isequal(a2, a)
210210
delete(a_mx)
211211

212+
a = sparse([1.0 1.0im])
213+
a_mx = mxarray(a)
214+
@test is_sparse(a_mx)
215+
@test is_double(a_mx)
216+
@test is_complex(a_mx)
217+
@test nrows(a_mx) == 1
218+
@test ncols(a_mx) == 2
219+
delete(a_mx)
220+
212221
# strings
213222

214223
s = "MATLAB.jl"
@@ -345,6 +354,13 @@ delete(x)
345354
@test isa(y, Array{Float64,3})
346355
@test isequal(y, a)
347356

357+
a = sparse([1.0 2.0im; 0 -1.0im])
358+
a_mx = mxarray(a)
359+
a_jl = jvalue(a_mx)
360+
delete(a_mx)
361+
@test a == a_jl
362+
@test isa(a_jl, SparseMatrixCSC{Complex{Float64}})
363+
348364
a = "MATLAB"
349365
x = mxarray(a)
350366
y = jvalue(x)

0 commit comments

Comments
 (0)