@@ -16,10 +16,28 @@ submodule (stdlib_linalg) stdlib_linalg_least_squares
1616
1717 contains
1818
!> Translate the `info` code returned by a LAPACK `gelsd` call into a linalg state.
!>
!> - `info == 0`: nothing is assigned; `err` keeps whatever value `intent(out)` gives it
!>   (presumably the default "success" state of `linalg_state_type` -- confirm against the type).
!> - `info < 0` : `err` is set to a `LINALG_VALUE_ERROR` reporting the sizes of `a` and `b`.
!> - `info > 0` : `err` is set to a `LINALG_ERROR` ("SVD did not converge.").
!>
!> `lda`, `n`, `ldb` and `nrhs` are only used to build the error message.
elemental subroutine handle_gelsd_info(info,lda,n,ldb,nrhs,err)
    !> Return code of gelsd
    integer(ilp), intent(in) :: info
    !> Problem sizes, reported in the error message: a is [lda,n], b is [ldb,nrhs]
    integer(ilp), intent(in) :: lda,n,ldb,nrhs
    !> Resulting state
    type(linalg_state_type), intent(out) :: err

    ! Process output
    select case (info)
       case (0)
           ! Success
       case (:-1)
           err = linalg_state_type(this,LINALG_VALUE_ERROR,'invalid problem size a=',[lda,n], &
                                                                              ', b=',[ldb,nrhs])
       case (1:)
           err = linalg_state_type(this,LINALG_ERROR,'SVD did not converge.')
       case default
           ! The ranges above already cover every integer value; kept as a safety net
           err = linalg_state_type(this,LINALG_INTERNAL_ERROR,'catastrophic error')
    end select

end subroutine handle_gelsd_info
36+
1937 #:for rk,rt,ri in RC_KINDS_TYPES
2038 #:if rk!="xdp"
21- ! Workspace needed by gesv
22- elemental subroutine ${ri}$gesv_space (m,n,nrhs,lrwork,liwork,lcwork)
39+ ! Workspace needed by gelsd
40+ elemental subroutine ${ri}$gelsd_space (m,n,nrhs,lrwork,liwork,lcwork)
2341 integer(ilp), intent(in) :: m,n,nrhs
2442 integer(ilp), intent(out) :: lrwork,liwork,lcwork
2543
@@ -53,7 +71,7 @@ submodule (stdlib_linalg) stdlib_linalg_least_squares
5371 lcwork = ceiling(1.25*lcwork,kind=ilp)
5472 liwork = ceiling(1.25*liwork,kind=ilp)
5573
56- end subroutine ${ri}$gesv_space
74+ end subroutine ${ri}$gelsd_space
5775
5876 #:endif
5977 #:endfor
@@ -93,33 +111,87 @@ submodule (stdlib_linalg) stdlib_linalg_least_squares
93111 !> [optional] state return flag. On error if not requested, the code will stop
94112 type(linalg_state_type), optional, intent(out) :: err
95113 !> Result array/matrix x[n] or x[n,nrhs]
96- ${rt}$, allocatable, target :: x${nd}$
114+ ${rt}$, allocatable, target :: x${nd}$
115+
116+ ! Initialize solution with the shape of the rhs
117+ allocate(x,mold=b)
118+
119+ call stdlib_linalg_${ri}$_solve_lstsq_${ndsuf}$(a,b,x,&
120+ cond=cond,overwrite_a=overwrite_a,rank=rank,err=err)
121+
122+ end function stdlib_linalg_${ri}$_lstsq_${ndsuf}$
123+
124+ ! Compute the least-squares solution to a real system of linear equations Ax = b
125+ module subroutine stdlib_linalg_${ri}$_solve_lstsq_${ndsuf}$(a,b,x, &
126+ real_storage,int_storage#{if rt.startswith('c')}#,cmpl_storage#{endif}#,cond,overwrite_a,rank,err)
127+
128+ !!### Summary
129+ !! Compute least-squares solution to a real system of linear equations \( Ax = b \)
130+ !!
131+ !!### Description
132+ !!
133+ !! This function computes the least-squares solution of a linear matrix problem.
134+ !!
135+ !! param: a Input matrix of size [m,n].
136+ !! param: b Right-hand-side vector of size [m] or matrix of size [m,nrhs].
137+ !! param: cond [optional] Real input threshold indicating that singular values `s_i <= cond*maxval(s)`
138+ !! do not contribute to the matrix rank.
139+ !! param: overwrite_a [optional] Flag indicating if the input matrix can be overwritten.
140+ !! param: rank [optional] integer flag returning matrix rank.
141+ !! param: err [optional] State return flag.
142+ !! return: x Solution vector of size [n] or solution matrix of size [n,nrhs].
143+ !!
144+ !> Input matrix a[m,n]
145+ ${rt}$, intent(inout), target :: a(:,:)
146+ !> Right hand side vector or array, b[m] or b[m,nrhs]
147+ ${rt}$, intent(in) :: b${nd}$
148+ !> Result array/matrix x[n] or x[n,nrhs]
149+ ${rt}$, intent(inout), contiguous, target :: x${nd}$
150+ !> [optional] real working storage space
151+ real(${rk}$), optional, intent(inout), target :: real_storage(:)
152+ !> [optional] integer working storage space
153+ integer(ilp), optional, intent(inout), target :: int_storage(:)
154+ #:if rt.startswith('c')
155+ !> [optional] complex working storage space
156+ ${rt}$, optional, intent(inout), target :: cmpl_storage(:)
157+ #:endif
158+ !> [optional] cutoff for rank evaluation: singular values s(i)<=cond*maxval(s) are considered 0.
159+ real(${rk}$), optional, intent(in) :: cond
160+ !> [optional] Can A,b data be overwritten and destroyed?
161+ logical(lk), optional, intent(in) :: overwrite_a
162+ !> [optional] Return rank of A
163+ integer(ilp), optional, intent(out) :: rank
164+ !> [optional] state return flag. On error if not requested, the code will stop
165+ type(linalg_state_type), optional, intent(out) :: err
97166
98167 !! Local variables
99168 type(linalg_state_type) :: err0
100- integer(ilp) :: m,n,lda,ldb,nrhs,info,mnmin,mnmax,arank,lrwork,liwork,lcwork
101- integer(ilp), allocatable :: iwork(:)
169+ integer(ilp) :: m,n,lda,ldb,nrhs,ldx,nrhsx,info,mnmin,mnmax,arank,lrwork,liwork,lcwork
170+ integer(ilp) :: nrs,nis,ncs
171+ integer(ilp), pointer :: iwork(:)
102172 logical(lk) :: copy_a
103173 real(${rk}$) :: acond,rcond
104- real(${rk}$), allocatable :: singular(:),rwork(:)
105- ${rt}$ , pointer :: xmat(:,:),amat(:, :)
106- ${rt}$, allocatable :: cwork(:)
174+ real(${rk}$), allocatable :: singular(:)
175+ real(${rk}$) , pointer :: rwork( :)
176+ ${rt}$, pointer :: xmat(:,:),amat(:,:), cwork(:)
107177
108178 ! Problem sizes
109179 m = size(a,1,kind=ilp)
110180 lda = size(a,1,kind=ilp)
111181 n = size(a,2,kind=ilp)
112182 ldb = size(b,1,kind=ilp)
113183 nrhs = size(b ,kind=ilp)/ldb
184+ ldx = size(x,1,kind=ilp)
185+ nrhsx = size(x ,kind=ilp)/ldx
114186 mnmin = min(m,n)
115187 mnmax = max(m,n)
116188
117- if (lda<1 .or. n<1 .or. ldb<1 .or. ldb/=m) then
189+ if (lda<1 .or. n<1 .or. ldb<1 .or. ldb/=m .or. ldx/=m ) then
118190 err0 = linalg_state_type(this,LINALG_VALUE_ERROR,'invalid sizes: a=',[lda,n], &
119- 'b=',[ldb,nrhs])
120- allocate(x${nde}$)
191+ 'b=',[ldb,nrhs],' x=',[ldx,nrhsx])
121192 call linalg_error_handling(err0,err)
122193 if (present(rank)) rank = 0
194+ return
123195 end if
124196
125197 ! Can A be overwritten? By default, do not overwrite
@@ -137,7 +209,7 @@ submodule (stdlib_linalg) stdlib_linalg_least_squares
137209 endif
138210
139211 ! Initialize solution with the rhs
140- allocate(x,source=b)
212+ x = b
141213 xmat(1:n,1:nrhs) => x
142214
143215 ! Singular values array (in decreasing order)
@@ -153,44 +225,71 @@ submodule (stdlib_linalg) stdlib_linalg_least_squares
153225 endif
154226 if (rcond<0) rcond = epsilon(0.0_${rk}$)*mnmax
155227
! Get the workspace sizes required by gelsd for this problem
call ${ri}$gelsd_space(m,n,nrhs,lrwork,liwork,lcwork)

! Real working space: use the caller's buffer when provided, otherwise allocate one
if (present(real_storage)) then
   rwork => real_storage
else
   allocate(rwork(lrwork))
endif
nrs = size(rwork,kind=ilp)

! Integer working space: same policy as the real one
if (present(int_storage)) then
   iwork => int_storage
else
   allocate(iwork(liwork))
endif
nis = size(iwork,kind=ilp)

#:if rt.startswith('complex')
! Complex working space: same policy as the real one
if (present(cmpl_storage)) then
   cwork => cmpl_storage
else
   allocate(cwork(lcwork))
endif
! Size of the complex buffer itself (not of iwork): it is checked against lcwork
! below and passed to gelsd as LWORK
ncs = size(cwork,kind=ilp)
#:endif

if (nrs<lrwork .or. nis<liwork#{if rt.startswith('c')}# .or. ncs<lcwork#{endif}#) then
   ! Halt on insufficient space
   err0 = linalg_state_type(this,LINALG_VALUE_ERROR,'insufficient working space: ',&
             'real=',nrs,' should be >=',lrwork, &
             ', int=',nis,' should be >=',liwork &
             #{if rt.startswith('complex')}#,', cmplx=',ncs,' should be >=',lcwork#{endif}#)

   ! gelsd is not run on this path: report a defined (zero) rank instead of an
   ! uninitialized value
   arank = 0_ilp

else

   ! Solve system using singular value decomposition.
   ! LWORK is the length of the WORK array: the complex buffer for complex kinds,
   ! the real buffer for real kinds.
   call gelsd(m,n,nrhs,amat,lda,xmat,ldb,singular,rcond,arank, &
   #:if rt.startswith('complex')
              cwork,ncs,rwork,iwork,info)
   #:else
              rwork,nrs,iwork,info)
   #:endif

   ! The condition number of A in the 2-norm = S(1)/S(min(m,n)).
   acond = singular(1)/singular(mnmin)

   ! Process output. Only done when gelsd actually ran: `info` and `singular` are
   ! undefined otherwise, and the intent(out) state of handle_gelsd_info would
   ! discard the "insufficient working space" error stored in err0.
   call handle_gelsd_info(info,lda,n,ldb,nrhs,err0)

endif

! Release locally owned storage, then report the state.
! NOTE(review): statement label 1 is kept from the original; no jump to it is visible
! in this part of the file -- confirm whether it is still needed.
1 if (copy_a) deallocate(amat)
if (present(rank)) rank = arank
! Work buffers are deallocated only when they were allocated here
if (.not.present(real_storage)) deallocate(rwork)
if (.not.present(int_storage)) deallocate(iwork)
#:if rt.startswith('complex')
if (.not.present(cmpl_storage)) deallocate(cwork)
#:endif
call linalg_error_handling(err0,err)

end subroutine stdlib_linalg_${ri}$_solve_lstsq_${ndsuf}$
194293
195294 #:endif
196295 #:endfor
0 commit comments