Skip to content

Commit 00b8699

Browse files
gtambatranslunar
authored andcommitted
Add #least_squares for least squares approximation (#539)
* clone input parameter to ormqr and unmqr to prevent it from being overwritten * Add #least_squares for linear least squares solution using QR factorization
1 parent f53cd76 commit 00b8699

File tree

3 files changed

+101
-2
lines changed

3 files changed

+101
-2
lines changed

lib/nmatrix/lapacke.rb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,7 @@ def ormqr(tau, side=:left, transpose=false, c=nil)
303303

304304

305305
#Default behaviour produces Q * I = Q if c is not supplied.
306-
result = c ? c : NMatrix.identity(self.shape[0], dtype: self.dtype)
306+
result = c ? c.clone : NMatrix.identity(self.shape[0], dtype: self.dtype)
307307
NMatrix::LAPACK::lapacke_ormqr(:row, side, transpose, result.shape[0], result.shape[1], tau.shape[0], self, self.shape[1], tau, result, result.shape[1])
308308

309309
result
@@ -343,7 +343,7 @@ def unmqr(tau, side=:left, transpose=false, c=nil)
343343
raise(TypeError, "c must have the same dtype as the calling NMatrix") if c and c.dtype != self.dtype
344344

345345
#Default behaviour produces Q * I = Q if c is not supplied.
346-
result = c ? c : NMatrix.identity(self.shape[0], dtype: self.dtype)
346+
result = c ? c.clone : NMatrix.identity(self.shape[0], dtype: self.dtype)
347347
NMatrix::LAPACK::lapacke_unmqr(:row, side, transpose, result.shape[0], result.shape[1], tau.shape[0], self, self.shape[1], tau, result, result.shape[1])
348348

349349
result

lib/nmatrix/math.rb

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -531,6 +531,77 @@ def solve(b, opts = {})
531531
end
532532
end
533533

534+
#
535+
# call-seq:
536+
# least_squares(b) -> NMatrix
537+
# least_squares(b, tolerance: 10e-10) -> NMatrix
538+
#
539+
# Provides the linear least squares approximation of an under-determined system
540+
# using QR factorization provided that the matrix is not rank-deficient.
541+
#
542+
# Only works for dense matrices.
543+
#
544+
# * *Arguments* :
545+
# - +b+ -> The solution column vector NMatrix of A * X = b.
546+
# - +tolerance:+ -> Absolute tolerance to check if a diagonal element in A = QR is near 0
547+
#
548+
# * *Returns* :
549+
# - NMatrix that is a column vector with the LLS solution
550+
#
551+
# * *Raises* :
552+
# - +ArgumentError+ -> least squares approximation only works for non-complex types
553+
# - +ShapeError+ -> system must be under-determined ( rows > columns )
554+
#
555+
# Examples :-
556+
#
557+
# a = NMatrix.new([3,2], [2.0, 0, -1, 1, 0, 2])
558+
#
559+
# b = NMatrix.new([3,1], [1.0, 0, -1])
560+
#
561+
# a.least_squares(b)
562+
# =>[
563+
# [ 0.33333333333333326 ]
564+
# [ -0.3333333333333334 ]
565+
# ]
566+
#
567+
def least_squares(b, tolerance: 10e-6)
568+
raise(ArgumentError, "least squares approximation only works for non-complex types") if
569+
self.complex_dtype?
570+
571+
rows, columns = self.shape
572+
573+
raise(ShapeError, "system must be under-determined ( rows > columns )") unless
574+
rows > columns
575+
576+
#Perform economical QR factorization
577+
r = self.clone
578+
tau = r.geqrf!
579+
q_transpose_b = r.ormqr(tau, :left, :transpose, b)
580+
581+
#Obtain R from geqrf! intermediate
582+
r[0...columns, 0...columns].upper_triangle!
583+
r[columns...rows, 0...columns] = 0
584+
585+
diagonal = r.diagonal
586+
587+
raise(ArgumentError, "rank deficient matrix") if diagonal.any? { |x| x == 0 }
588+
589+
if diagonal.any? { |x| x.abs < tolerance }
590+
warn "warning: A diagonal element of R in A = QR is close to zero ;" <<
591+
" indicates a possible loss of precision"
592+
end
593+
594+
# Transform the system A * X = B to R1 * X = B2 where B2 = Q1_t * B
595+
r1 = r[0...columns, 0...columns]
596+
b2 = q_transpose_b[0...columns]
597+
598+
nrhs = b2.shape[1]
599+
600+
#Solve the upper triangular system
601+
NMatrix::BLAS::cblas_trsm(:row, :left, :upper, false, :nounit, r1.shape[0], nrhs, 1.0, r1, r1.shape[0], b2, nrhs)
602+
b2
603+
end
604+
534605
#
535606
# call-seq:
536607
# gesvd! -> [u, sigma, v_transpose]

spec/math_spec.rb

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -849,6 +849,34 @@
849849
end
850850
end
851851

852+
context "#least_squares" do
853+
it "finds the least squares approximation to the equation A * X = B" do
854+
a = NMatrix.new([3,2], [2.0, 0, -1, 1, 0, 2])
855+
b = NMatrix.new([3,1], [1.0, 0, -1])
856+
solution = NMatrix.new([2,1], [1.0 / 3 , -1.0 / 3], dtype: :float64)
857+
858+
begin
859+
least_squares = a.least_squares(b)
860+
expect(least_squares).to be_within(0.0001).of solution
861+
rescue NotImplementedError
862+
"Suppressing a NotImplementedError when the lapacke or atlas plugin is not available"
863+
end
864+
end
865+
866+
it "finds the least squares approximation to the equation A * X = B with high tolerance" do
867+
a = NMatrix.new([4,2], [1.0, 1, 1, 2, 1, 3,1,4])
868+
b = NMatrix.new([4,1], [6.0, 5, 7, 10])
869+
solution = NMatrix.new([2,1], [3.5 , 1.4], dtype: :float64)
870+
871+
begin
872+
least_squares = a.least_squares(b, tolerance: 10e-5)
873+
expect(least_squares).to be_within(0.0001).of solution
874+
rescue NotImplementedError
875+
"Suppressing a NotImplementedError when the lapacke or atlas plugin is not available"
876+
end
877+
end
878+
end
879+
852880
context "#hessenberg" do
853881
FLOAT_DTYPES.each do |dtype|
854882
context dtype do

0 commit comments

Comments
 (0)