Skip to content

Commit d545e7f

Browse files
npriyadarshitranslunar
authored andcommitted
fixes inverse_exact methods for dense, yale and list stype (#585)
* Corrected the tests in 00_nmatrix_sec.rb and inverse_exact methods * added inverse_exact# in lib/nmatrix/math.rb * Removed the StorageTypeError * Changed the name of the method to exact_inverse!
1 parent 40cc770 commit d545e7f

File tree

4 files changed

+111
-43
lines changed

4 files changed

+111
-43
lines changed

ext/nmatrix/math.cpp

Lines changed: 40 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ extern "C" {
188188
// Math Functions //
189189
////////////////////
190190

191-
namespace nm {
191+
namespace nm {
192192
namespace math {
193193

194194
/*
@@ -232,8 +232,8 @@ namespace nm {
232232
DType* a = reinterpret_cast<DType*>(storage->a);
233233
IType col_pos = storage->shape[0] + 1;
234234
if (M == 2) {
235-
if (ija[2] - ija[0] == 2) {
236-
*result = a[0] * a[1] - a[col_pos] * a[col_pos+1];
235+
if (ija[2] - ija[0] == 2) {
236+
*result = a[0] * a[1] - a[col_pos] * a[col_pos+1];
237237
}
238238
else { *result = a[0] * a[1]; }
239239
} else if (M == 3) {
@@ -258,7 +258,7 @@ namespace nm {
258258
rb_raise(rb_eArgError, "some value in IJA is incorrect!");
259259
}
260260
}
261-
*result =
261+
*result =
262262
m[0][0] * m[1][1] * m[2][2] + m[0][1] * m[1][2] * m[2][0] + m[0][2] * m[1][0] * m[2][1]
263263
- m[0][0] * m[1][2] * m[2][1] - m[0][1] * m[1][0] * m[2][2] - m[0][2] * m[1][1] * m[2][0];
264264

@@ -270,13 +270,13 @@ namespace nm {
270270
}
271271

272272
/*
273-
* Solve a system of linear equations using forward-substution followed by
273+
* Solve a system of linear equations using forward-substution followed by
274274
* back substution from the LU factorization of the matrix of co-efficients.
275275
* Replaces x_elements with the result. Works only with non-integer, non-object
276276
* data types.
277277
*
278278
* args - r -> The number of rows of the matrix.
279-
* lu_elements -> Elements of the LU decomposition of the co-efficients
279+
* lu_elements -> Elements of the LU decomposition of the co-efficients
280280
* matrix, as a contiguos array.
281281
* b_elements -> Elements of the the right hand sides, as a contiguous array.
282282
* x_elements -> The array that will contain the results of the computation.
@@ -291,7 +291,7 @@ namespace nm {
291291
const DType* b = reinterpret_cast<const DType*>(b_elements);
292292
DType* x = reinterpret_cast<DType*>(x_elements);
293293

294-
for (int i = 0; i < r; ++i) { x[i] = b[i]; }
294+
for (int i = 0; i < r; ++i) { x[i] = b[i]; }
295295
for (int i = 0; i < r; ++i) { // forward substitution loop
296296
ip = pivot[i];
297297
sum = x[ip];
@@ -335,18 +335,18 @@ namespace nm {
335335
for (int row = k + 1; row < M; ++row) {
336336
typename MagnitudeDType<DType>::type big;
337337
big = magnitude( matrix[M*row + k] ); // element below the temp pivot
338-
338+
339339
if ( big > akk ) {
340340
interchange = row;
341-
akk = big;
341+
akk = big;
342342
}
343-
}
343+
}
344344

345345
if (interchange != k) { // check if rows need flipping
346346
DType temp;
347347

348348
for (int col = 0; col < M; ++col) {
349-
NM_SWAP(matrix[interchange*M + col], matrix[k*M + col], temp);
349+
NM_SWAP(matrix[interchange*M + col], matrix[k*M + col], temp);
350350
}
351351
}
352352

@@ -360,7 +360,7 @@ namespace nm {
360360
DType pivot = matrix[k * (M + 1)];
361361
matrix[k * (M + 1)] = (DType)(1); // set diagonal as 1 for in-place inversion
362362

363-
for (int col = 0; col < M; ++col) {
363+
for (int col = 0; col < M; ++col) {
364364
// divide each element in the kth row with the pivot
365365
matrix[k*M + col] = matrix[k*M + col] / pivot;
366366
}
@@ -369,7 +369,7 @@ namespace nm {
369369
if (kk == k) continue;
370370

371371
DType dum = matrix[k + M*kk];
372-
matrix[k + M*kk] = (DType)(0); // prepare for inplace inversion
372+
matrix[k + M*kk] = (DType)(0); // prepare for inplace inversion
373373
for (int col = 0; col < M; ++col) {
374374
matrix[M*kk + col] = matrix[M*kk + col] - matrix[M*k + col] * dum;
375375
}
@@ -384,7 +384,7 @@ namespace nm {
384384

385385
for (int row = 0; row < M; ++row) {
386386
NM_SWAP(matrix[row * M + row_index[k]], matrix[row * M + col_index[k]],
387-
temp);
387+
temp);
388388
}
389389
}
390390
}
@@ -410,14 +410,14 @@ namespace nm {
410410
DType sum_of_squares, *p_row, *psubdiag, *p_a, scale, innerproduct;
411411
int i, k, col;
412412

413-
// For each column use a Householder transformation to zero all entries
413+
// For each column use a Householder transformation to zero all entries
414414
// below the subdiagonal.
415-
for (psubdiag = a + nrows, col = 0; col < nrows - 2; psubdiag += nrows + 1,
415+
for (psubdiag = a + nrows, col = 0; col < nrows - 2; psubdiag += nrows + 1,
416416
col++) {
417417
// Calculate the signed square root of the sum of squares of the
418418
// elements below the diagonal.
419419

420-
for (p_a = psubdiag, sum_of_squares = 0.0, i = col + 1; i < nrows;
420+
for (p_a = psubdiag, sum_of_squares = 0.0, i = col + 1; i < nrows;
421421
p_a += nrows, i++) {
422422
sum_of_squares += *p_a * *p_a;
423423
}
@@ -447,7 +447,7 @@ namespace nm {
447447
*p_a -= u[k] * innerproduct;
448448
}
449449
}
450-
450+
451451
// Postmultiply QA by Q
452452
for (p_row = a, i = 0; i < nrows; p_row += nrows, i++) {
453453
for (innerproduct = 0.0, k = col + 1; k < nrows; k++) {
@@ -465,15 +465,15 @@ namespace nm {
465465
}
466466

467467
void raise_not_invertible_error() {
468-
rb_raise(nm_eNotInvertibleError,
468+
rb_raise(nm_eNotInvertibleError,
469469
"matrix must have non-zero determinant to be invertible (not getting this error does not mean matrix is invertible if you're dealing with floating points)");
470470
}
471471

472472
/*
473473
* Calculate the exact inverse for a dense matrix (A [elements]) of size 2 or 3. Places the result in B_elements.
474474
*/
475475
template <typename DType>
476-
void inverse_exact_from_dense(const int M, const void* A_elements,
476+
void inverse_exact_from_dense(const int M, const void* A_elements,
477477
const int lda, void* B_elements, const int ldb) {
478478

479479
const DType* A = reinterpret_cast<const DType*>(A_elements);
@@ -485,7 +485,7 @@ namespace nm {
485485
B[0] = A[lda+1] / det;
486486
B[1] = -A[1] / det;
487487
B[ldb] = -A[lda] / det;
488-
B[ldb+1] = -A[0] / det;
488+
B[ldb+1] = A[0] / det;
489489

490490
} else if (M == 3) {
491491
// Calculate the exact determinant.
@@ -510,7 +510,7 @@ namespace nm {
510510
}
511511

512512
template <typename DType>
513-
void inverse_exact_from_yale(const int M, const YALE_STORAGE* storage,
513+
void inverse_exact_from_yale(const int M, const YALE_STORAGE* storage,
514514
const int lda, YALE_STORAGE* inverse, const int ldb) {
515515

516516
// inverse is a clone of storage
@@ -524,18 +524,18 @@ namespace nm {
524524

525525
if (M == 2) {
526526
IType ndnz = ija[2] - ija[0];
527-
if (ndnz == 2) {
528-
det = a[0] * a[1] - a[col_pos] * a[col_pos+1];
527+
if (ndnz == 2) {
528+
det = a[0] * a[1] - a[col_pos] * a[col_pos+1];
529529
}
530530
else { det = a[0] * a[1]; }
531531
if (det == 0) { raise_not_invertible_error(); }
532532
b[0] = a[1] / det;
533-
b[1] = -a[0] / det;
534-
if (ndnz == 2) {
533+
b[1] = a[0] / det;
534+
if (ndnz == 2) {
535535
b[col_pos] = -a[col_pos] / det;
536536
b[col_pos+1] = -a[col_pos+1] / det;
537537
}
538-
else if (ndnz == 1) {
538+
else if (ndnz == 1) {
539539
b[col_pos] = -a[col_pos] / det;
540540
}
541541

@@ -561,7 +561,7 @@ namespace nm {
561561
rb_raise(rb_eArgError, "some value in IJA is incorrect!");
562562
}
563563
}
564-
det =
564+
det =
565565
A[0] * A[lda+1] * A[2*lda+2] + A[1] * A[lda+2] * A[2*lda] + A[2] * A[lda] * A[2*lda+1]
566566
- A[0] * A[lda+2] * A[2*lda+1] - A[1] * A[lda] * A[2*lda+2] - A[2] * A[lda+1] * A[2*lda];
567567
if (det == 0) { raise_not_invertible_error(); }
@@ -1267,9 +1267,9 @@ static VALUE nm_clapack_laswp(VALUE self, VALUE n, VALUE a, VALUE lda, VALUE k1,
12671267
/*
12681268
* C accessor for calculating an exact determinant. Dense matrix version.
12691269
*/
1270-
void nm_math_det_exact_from_dense(const int M, const void* elements, const int lda,
1270+
void nm_math_det_exact_from_dense(const int M, const void* elements, const int lda,
12711271
nm::dtype_t dtype, void* result) {
1272-
NAMED_DTYPE_TEMPLATE_TABLE(ttable, nm::math::det_exact_from_dense, void, const int M,
1272+
NAMED_DTYPE_TEMPLATE_TABLE(ttable, nm::math::det_exact_from_dense, void, const int M,
12731273
const void* A_elements, const int lda, void* result_arg);
12741274

12751275
ttable[dtype](M, elements, lda, result);
@@ -1278,27 +1278,27 @@ void nm_math_det_exact_from_dense(const int M, const void* elements, const int l
12781278
/*
12791279
* C accessor for calculating an exact determinant. Yale matrix version.
12801280
*/
1281-
void nm_math_det_exact_from_yale(const int M, const YALE_STORAGE* storage, const int lda,
1281+
void nm_math_det_exact_from_yale(const int M, const YALE_STORAGE* storage, const int lda,
12821282
nm::dtype_t dtype, void* result) {
1283-
NAMED_DTYPE_TEMPLATE_TABLE(ttable, nm::math::det_exact_from_yale, void, const int M,
1283+
NAMED_DTYPE_TEMPLATE_TABLE(ttable, nm::math::det_exact_from_yale, void, const int M,
12841284
const YALE_STORAGE* storage, const int lda, void* result_arg);
12851285

12861286
ttable[dtype](M, storage, lda, result);
12871287
}
12881288

12891289
/*
1290-
* C accessor for solving a system of linear equations.
1290+
* C accessor for solving a system of linear equations.
12911291
*/
12921292
void nm_math_solve(VALUE lu, VALUE b, VALUE x, VALUE ipiv) {
12931293
int* pivot = new int[RARRAY_LEN(ipiv)];
12941294

12951295
for (int i = 0; i < RARRAY_LEN(ipiv); ++i) {
1296-
pivot[i] = FIX2INT(rb_ary_entry(ipiv, i));
1296+
pivot[i] = FIX2INT(rb_ary_entry(ipiv, i));
12971297
}
12981298

12991299
NAMED_DTYPE_TEMPLATE_TABLE(ttable, nm::math::solve, void, const int, const void*, const void*, void*, const int*);
13001300

1301-
ttable[NM_DTYPE(x)](NM_SHAPE0(b), NM_STORAGE_DENSE(lu)->elements,
1301+
ttable[NM_DTYPE(x)](NM_SHAPE0(b), NM_STORAGE_DENSE(lu)->elements,
13021302
NM_STORAGE_DENSE(b)->elements, NM_STORAGE_DENSE(x)->elements, pivot);
13031303
}
13041304

@@ -1313,7 +1313,7 @@ void nm_math_hessenberg(VALUE a) {
13131313
NULL, NULL, // does not support Complex
13141314
NULL // no support for Ruby Object
13151315
};
1316-
1316+
13171317
ttable[NM_DTYPE(a)](NM_SHAPE0(a), NM_STORAGE_DENSE(a)->elements);
13181318
}
13191319
/*
@@ -1328,10 +1328,10 @@ void nm_math_inverse(const int M, void* a_elements, nm::dtype_t dtype) {
13281328
/*
13291329
* C accessor for calculating an exact inverse. Dense matrix version.
13301330
*/
1331-
void nm_math_inverse_exact_from_dense(const int M, const void* A_elements,
1331+
void nm_math_inverse_exact_from_dense(const int M, const void* A_elements,
13321332
const int lda, void* B_elements, const int ldb, nm::dtype_t dtype) {
13331333

1334-
NAMED_DTYPE_TEMPLATE_TABLE(ttable, nm::math::inverse_exact_from_dense, void,
1334+
NAMED_DTYPE_TEMPLATE_TABLE(ttable, nm::math::inverse_exact_from_dense, void,
13351335
const int, const void*, const int, void*, const int);
13361336

13371337
ttable[dtype](M, A_elements, lda, B_elements, ldb);
@@ -1340,10 +1340,10 @@ void nm_math_inverse_exact_from_dense(const int M, const void* A_elements,
13401340
/*
13411341
* C accessor for calculating an exact inverse. Yale matrix version.
13421342
*/
1343-
void nm_math_inverse_exact_from_yale(const int M, const YALE_STORAGE* storage,
1343+
void nm_math_inverse_exact_from_yale(const int M, const YALE_STORAGE* storage,
13441344
const int lda, YALE_STORAGE* inverse, const int ldb, nm::dtype_t dtype) {
13451345

1346-
NAMED_DTYPE_TEMPLATE_TABLE(ttable, nm::math::inverse_exact_from_yale, void,
1346+
NAMED_DTYPE_TEMPLATE_TABLE(ttable, nm::math::inverse_exact_from_yale, void,
13471347
const int, const YALE_STORAGE*, const int, YALE_STORAGE*, const int);
13481348

13491349
ttable[dtype](M, storage, lda, inverse, ldb);

lib/nmatrix/math.rb

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,58 @@ def invert
112112
end
113113
alias :inverse :invert
114114

115+
# call-seq:
116+
# exact_inverse! -> NMatrix
117+
#
118+
# Calulates inverse_exact of a matrix of size 2 or 3.
119+
# Only works on dense matrices.
120+
#
121+
# * *Raises* :
122+
# - +DataTypeError+ -> cannot invert an integer matrix in-place.
123+
# - +NotImplementedError+ -> cannot find exact inverse of matrix with size greater than 3 #
124+
def exact_inverse!
125+
raise(ShapeError, "Cannot invert non-square matrix") unless self.dim == 2 && self.shape[0] == self.shape[1]
126+
raise(DataTypeError, "Cannot invert an integer matrix in-place") if self.integer_dtype?
127+
#No internal implementation of getri, so use this other function
128+
n = self.shape[0]
129+
if n>3
130+
raise(NotImplementedError, "Cannot find exact inverse of matrix of size greater than 3")
131+
else
132+
clond=self.clone
133+
__inverse_exact__(clond, n, n)
134+
end
135+
end
136+
137+
#
138+
# call-seq:
139+
# exact_inverse -> NMatrix
140+
#
141+
# Make a copy of the matrix, then invert using exact_inverse
142+
#
143+
# * *Returns* :
144+
# - A dense NMatrix. Will be the same type as the input NMatrix,
145+
# except if the input is an integral dtype, in which case it will be a
146+
# :float64 NMatrix.
147+
#
148+
# * *Raises* :
149+
# - +StorageTypeError+ -> only implemented on dense matrices.
150+
# - +ShapeError+ -> matrix must be square.
151+
# - +NotImplementedError+ -> cannot find exact inverse of matrix with size greater than 3
152+
#
153+
def exact_inverse
154+
#write this in terms of exact_inverse! so plugins will only have to overwrite
155+
#exact_inverse! and not exact_inverse
156+
if self.integer_dtype?
157+
cloned = self.cast(dtype: :float64)
158+
cloned.exact_inverse!
159+
else
160+
cloned = self.clone
161+
cloned.exact_inverse!
162+
end
163+
end
164+
alias :invert_exactly :exact_inverse
165+
166+
115167

116168
#
117169
# call-seq:

spec/00_nmatrix_spec.rb

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@
7373

7474
e = NMatrix.new(2, [3,1,2,1], stype: :dense, dtype: :int64)
7575
inversed = e.method(:__inverse_exact__).call(e.clone, 2, 2)
76-
f = NMatrix.new(2, [1,-1,-2,-3], stype: :dense, dtype: :int64)
76+
f = NMatrix.new(2, [1,-1,-2,3], stype: :dense, dtype: :int64)
7777
expect(inversed).to eq(f)
7878
end
7979

@@ -91,7 +91,7 @@
9191

9292
e = NMatrix.new(2, [3,1,2,1], stype: :yale, dtype: :int64)
9393
inversed = e.method(:__inverse_exact__).call(e.clone, 2, 2)
94-
f = NMatrix.new(2, [1,-1,-2,-3], stype: :yale, dtype: :int64)
94+
f = NMatrix.new(2, [1,-1,-2,3], stype: :yale, dtype: :int64)
9595
expect(inversed).to eq(f)
9696
end
9797

@@ -104,7 +104,7 @@
104104

105105
c = NMatrix.new(2, [3,1,2,1], stype: :list, dtype: :int64)
106106
inversed = c.method(:__inverse_exact__).call(c.clone, 2, 2)
107-
d = NMatrix.new(2, [1,-1,-2,-3], stype: :list, dtype: :int64)
107+
d = NMatrix.new(2, [1,-1,-2,3], stype: :list, dtype: :int64)
108108
expect(inversed).to eq(d)
109109
end
110110

spec/math_spec.rb

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -495,6 +495,22 @@
495495

496496
expect(a.invert).to be_within(err).of(b)
497497
end
498+
499+
it "should correctly find exact inverse" do
500+
pending("not yet implemented for NMatrix-JRuby") if jruby?
501+
a = NMatrix.new(:dense, 3, [1,2,3,0,1,4,5,6,0], dtype)
502+
b = NMatrix.new(:dense, 3, [-24,18,5,20,-15,-4,-5,4,1], dtype)
503+
504+
expect(a.exact_inverse).to be_within(err).of(b)
505+
end
506+
507+
it "should correctly find exact inverse" do
508+
pending("not yet implemented for NMatrix-JRuby") if jruby?
509+
a = NMatrix.new(:dense, 2, [1,3,3,8], dtype)
510+
b = NMatrix.new(:dense, 2, [-8,3,3,-1], dtype)
511+
512+
expect(a.exact_inverse).to be_within(err).of(b)
513+
end
498514
end
499515
end
500516

0 commit comments

Comments
 (0)