Skip to content

Commit d2439c0

Browse files
authored
Merge pull request #564 from saru95/master
Added NMatrix#cumsum alias for NMatrix#sum
2 parents 1f5b0f8 + 24eb111 commit d2439c0

File tree

1 file changed

+17
-15
lines changed

1 file changed

+17
-15
lines changed

lib/nmatrix/math.rb

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -536,15 +536,15 @@ def solve(b, opts = {})
536536
# least_squares(b) -> NMatrix
537537
# least_squares(b, tolerance: 10e-10) -> NMatrix
538538
#
539-
# Provides the linear least squares approximation of an under-determined system
539+
# Provides the linear least squares approximation of an under-determined system
540540
# using QR factorization provided that the matrix is not rank-deficient.
541541
#
542542
# Only works for dense matrices.
543543
#
544544
# * *Arguments* :
545545
# - +b+ -> The solution column vector NMatrix of A * X = b.
546-
# - +tolerance:+ -> Absolute tolerance to check if a diagonal element in A = QR is near 0
547-
#
546+
# - +tolerance:+ -> Absolute tolerance to check if a diagonal element in A = QR is near 0
547+
#
548548
# * *Returns* :
549549
# - NMatrix that is a column vector with the LLS solution
550550
#
@@ -554,8 +554,8 @@ def solve(b, opts = {})
554554
#
555555
# Examples :-
556556
#
557-
# a = NMatrix.new([3,2], [2.0, 0, -1, 1, 0, 2])
558-
#
557+
# a = NMatrix.new([3,2], [2.0, 0, -1, 1, 0, 2])
558+
#
559559
# b = NMatrix.new([3,1], [1.0, 0, -1])
560560
#
561561
# a.least_squares(b)
@@ -564,30 +564,30 @@ def solve(b, opts = {})
564564
# [ -0.3333333333333334 ]
565565
# ]
566566
#
567-
def least_squares(b, tolerance: 10e-6)
568-
raise(ArgumentError, "least squares approximation only works for non-complex types") if
567+
def least_squares(b, tolerance: 10e-6)
568+
raise(ArgumentError, "least squares approximation only works for non-complex types") if
569569
self.complex_dtype?
570-
570+
571571
rows, columns = self.shape
572572

573-
raise(ShapeError, "system must be under-determined ( rows > columns )") unless
573+
raise(ShapeError, "system must be under-determined ( rows > columns )") unless
574574
rows > columns
575-
575+
576576
#Perform economical QR factorization
577577
r = self.clone
578578
tau = r.geqrf!
579579
q_transpose_b = r.ormqr(tau, :left, :transpose, b)
580-
580+
581581
#Obtain R from geqrf! intermediate
582582
r[0...columns, 0...columns].upper_triangle!
583583
r[columns...rows, 0...columns] = 0
584-
584+
585585
diagonal = r.diagonal
586586

587587
raise(ArgumentError, "rank deficient matrix") if diagonal.any? { |x| x == 0 }
588-
588+
589589
if diagonal.any? { |x| x.abs < tolerance }
590-
warn "warning: A diagonal element of R in A = QR is close to zero ;" <<
590+
warn "warning: A diagonal element of R in A = QR is close to zero ;" <<
591591
" indicates a possible loss of precision"
592592
end
593593

@@ -981,7 +981,9 @@ def mean(dimen=0)
981981
##
982982
# call-seq:
983983
# sum() -> NMatrix
984+
# cumsum() -> NMatrix
984985
# sum(dimen) -> NMatrix
986+
# cumsum(dimen) -> NMatrix
985987
#
986988
# Calculates the sum along the specified dimension.
987989
#
@@ -991,7 +993,7 @@ def sum(dimen=0)
991993
sum + sub_mat
992994
end
993995
end
994-
996+
alias :cumsum :sum
995997

996998
##
997999
# call-seq:

0 commit comments

Comments
 (0)