@@ -531,6 +531,77 @@ def solve(b, opts = {})
531531 end
532532 end
533533
534+ #
535+ # call-seq:
536+ # least_squares(b) -> NMatrix
537+ # least_squares(b, tolerance: 10e-10) -> NMatrix
538+ #
539+ # Provides the linear least squares approximation of an under-determined system
540+ # using QR factorization provided that the matrix is not rank-deficient.
541+ #
542+ # Only works for dense matrices.
543+ #
544+ # * *Arguments* :
545+ # - +b+ -> The solution column vector NMatrix of A * X = b.
546+ # - +tolerance:+ -> Absolute tolerance to check if a diagonal element in A = QR is near 0
547+ #
548+ # * *Returns* :
549+ # - NMatrix that is a column vector with the LLS solution
550+ #
551+ # * *Raises* :
552+ # - +ArgumentError+ -> least squares approximation only works for non-complex types
553+ # - +ShapeError+ -> system must be under-determined ( rows > columns )
554+ #
555+ # Examples :-
556+ #
557+ # a = NMatrix.new([3,2], [2.0, 0, -1, 1, 0, 2])
558+ #
559+ # b = NMatrix.new([3,1], [1.0, 0, -1])
560+ #
561+ # a.least_squares(b)
562+ # =>[
563+ # [ 0.33333333333333326 ]
564+ # [ -0.3333333333333334 ]
565+ # ]
566+ #
567+ def least_squares ( b , tolerance : 10e-6 )
568+ raise ( ArgumentError , "least squares approximation only works for non-complex types" ) if
569+ self . complex_dtype?
570+
571+ rows , columns = self . shape
572+
573+ raise ( ShapeError , "system must be under-determined ( rows > columns )" ) unless
574+ rows > columns
575+
576+ #Perform economical QR factorization
577+ r = self . clone
578+ tau = r . geqrf!
579+ q_transpose_b = r . ormqr ( tau , :left , :transpose , b )
580+
581+ #Obtain R from geqrf! intermediate
582+ r [ 0 ...columns , 0 ...columns ] . upper_triangle!
583+ r [ columns ...rows , 0 ...columns ] = 0
584+
585+ diagonal = r . diagonal
586+
587+ raise ( ArgumentError , "rank deficient matrix" ) if diagonal . any? { |x | x == 0 }
588+
589+ if diagonal . any? { |x | x . abs < tolerance }
590+ warn "warning: A diagonal element of R in A = QR is close to zero ;" <<
591+ " indicates a possible loss of precision"
592+ end
593+
594+ # Transform the system A * X = B to R1 * X = B2 where B2 = Q1_t * B
595+ r1 = r [ 0 ...columns , 0 ...columns ]
596+ b2 = q_transpose_b [ 0 ...columns ]
597+
598+ nrhs = b2 . shape [ 1 ]
599+
600+ #Solve the upper triangular system
601+ NMatrix ::BLAS ::cblas_trsm ( :row , :left , :upper , false , :nounit , r1 . shape [ 0 ] , nrhs , 1.0 , r1 , r1 . shape [ 0 ] , b2 , nrhs )
602+ b2
603+ end
604+
534605 #
535606 # call-seq:
536607 # gesvd! -> [u, sigma, v_transpose]
0 commit comments