@@ -9,44 +9,37 @@ module test_linalg_svd
99 use stdlib_linalg_state, only: linalg_state_type
1010
1111 implicit none (type,external)
12+
13+ public :: test_svd
1214
1315 contains
1416
15- !> SVD tests
16- subroutine test_svd(error)
17- logical, intent(out) :: error
18-
19- real :: t0,t1
20-
21- call cpu_time(t0)
17+ !> Solve several SVD problems
18+ subroutine test_svd(tests)
19+ !> Collection of tests
20+ type(unittest_type), allocatable, intent(out) :: tests(:)
21+
22+ allocate(tests(0))
2223
2324 #:for rk,rt,ri in REAL_KINDS_TYPES
2425 #:if rk!="xdp"
25- call test_svd_${ri}$(error)
26- if (error) return
26+ tests = [tests,new_unittest("test_svd_${ri}$",test_svd_${ri}$)]
2727 #:endif
2828 #:endfor
2929
3030 #:for ck,ct,ci in CMPLX_KINDS_TYPES
3131 #:if ck!="xdp"
32- call test_complex_svd_${ci}$(error)
33- if (error) return
32+ tests = [tests,new_unittest("test_complex_svd_${ci}$",test_complex_svd_${ci}$)]
3433 #:endif
3534 #:endfor
3635
37- call cpu_time(t1)
38-
39- print 1, 1000*(t1-t0), merge('SUCCESS','ERROR ',.not.error)
40-
41- 1 format('SVD tests completed in ',f9.4,' milliseconds, result=',a)
42-
4336 end subroutine test_svd
4437
4538 !> Real matrix svd
4639 #:for rk,rt,ri in REAL_KINDS_TYPES
4740 #:if rk!="xdp"
4841 subroutine test_svd_${ri}$(error)
49- logical, intent(out) :: error
42+ type(error_type), allocatable, intent(out) :: error
5043
5144 !> Reference solution
5245 ${rt}$, parameter :: tol = sqrt(epsilon(0.0_${rk}$))
@@ -63,6 +56,7 @@ module test_linalg_svd
6356 0.0_${rk}$,4*rsqrt18,-third],[3,3])
6457
6558 !> Local variables
59+ character(:), allocatable :: test
6660 type(linalg_state_type) :: state
6761 ${rt}$ :: A(2,3),s(2),u(2,2),vt(3,3)
6862
@@ -71,72 +65,110 @@ module test_linalg_svd
7165
7266 !> Simple subroutine version
7367 call svd(A,s,err=state)
74- error = state%error() .or. .not. all(abs(s-s_sol)<=tol)
75- if (error) return
76-
68+
69+ test = 'subroutine version'
70+ call check(error,state%ok(),test//': '//state%print())
71+ if (allocated(error)) return
72+ call check(error, all(abs(s-s_sol)<=tol), test//': S')
73+ if (allocated(error)) return
74+
7775 !> Function interface
7876 s = svdvals(A,err=state)
79- error = state%error() .or. .not. all(abs(s-s_sol)<=tol)
80- if (error) return
77+
78+ test = 'function interface'
79+ call check(error,state%ok(),test//': '//state%print())
80+ if (allocated(error)) return
81+ call check(error, all(abs(s-s_sol)<=tol), test//': S')
82+ if (allocated(error)) return
8183
8284 !> [S, U]. Singular vectors could be all flipped
8385 call svd(A,s,u,err=state)
84- error = state%error() .or. &
85- .not. all(abs(s-s_sol)<=tol) .or. &
86- .not.(all(abs(u-u_sol)<=tol) .or. all(abs(u+u_sol)<=tol))
87- if (error) return
86+
87+ test = 'subroutine with singular vectors'
88+ call check(error,state%ok(),test//': '//state%print())
89+ if (allocated(error)) return
90+ call check(error, all(abs(s-s_sol)<=tol), test//': S')
91+ if (allocated(error)) return
92+ call check(error, all(abs(u-u_sol)<=tol) .or. all(abs(u+u_sol)<=tol), test//': U')
93+ if (allocated(error)) return
8894
8995 !> [S, U]. Overwrite A matrix
9096 call svd(A,s,u,overwrite_a=.true.,err=state)
91- error = state%error() .or. &
92- .not. all(abs(s-s_sol)<=tol) .or. &
93- .not.(all(abs(u-u_sol)<=tol) .or. all(abs(u+u_sol)<=tol))
94- if (error) return
97+
98+ test = 'subroutine, overwrite_a'
99+ call check(error,state%ok(),test//': '//state%print())
100+ if (allocated(error)) return
101+ call check(error, all(abs(s-s_sol)<=tol), test//': S')
102+ if (allocated(error)) return
103+ call check(error, all(abs(u-u_sol)<=tol) .or. all(abs(u+u_sol)<=tol), test//': U')
104+ if (allocated(error)) return
95105
96106 !> [S, U, V^T]
97107 A = A_mat
98108 call svd(A,s,u,vt,overwrite_a=.true.,err=state)
99- error = state%error() .or. &
100- .not. all(abs(s-s_sol)<=tol) .or. &
101- .not.(all(abs(u-u_sol)<=tol) .or. all(abs(u+u_sol)<=tol)) .or. &
102- .not.(all(abs(vt-vt_sol)<=tol) .or. all(abs(vt+vt_sol)<=tol))
103- if (error) return
104-
109+
110+ test = '[S, U, V^T]'
111+ call check(error,state%ok(),test//': '//state%print())
112+ if (allocated(error)) return
113+ call check(error, all(abs(s-s_sol)<=tol), test//': S')
114+ if (allocated(error)) return
115+ call check(error, all(abs(u-u_sol)<=tol) .or. all(abs(u+u_sol)<=tol), test//': U')
116+ if (allocated(error)) return
117+ call check(error, all(abs(vt-vt_sol)<=tol) .or. all(abs(vt+vt_sol)<=tol), test//': V^T')
118+ if (allocated(error)) return
119+
105120 !> [S, V^T]. Do not overwrite A matrix
106121 A = A_mat
107122 call svd(A,s,vt=vt,err=state)
108- error = state%error() .or. &
109- .not. all(abs(s-s_sol)<=tol) .or. &
110- .not.(all(abs(vt+vt_sol)<=tol) .or. all(abs(vt+vt_sol)<=tol))
111- if (error) return
123+
124+ test = '[S, V^T], overwrite_a=.false.'
125+ call check(error,state%ok(),test//': '//state%print())
126+ if (allocated(error)) return
127+ call check(error, all(abs(s-s_sol)<=tol), test//': S')
128+ if (allocated(error)) return
129+ call check(error, all(abs(vt-vt_sol)<=tol) .or. all(abs(vt+vt_sol)<=tol), test//': V^T')
130+ if (allocated(error)) return
112131
113132 !> [S, V^T]. Overwrite A matrix
114133 call svd(A,s,vt=vt,overwrite_a=.true.,err=state)
115- error = state%error() .or. &
116- .not. all(abs(s-s_sol)<=tol) .or. &
117- .not.(all(abs(vt-vt_sol)<=tol) .or. all(abs(vt+vt_sol)<=tol))
118- if (error) return
119-
134+
135+ test = '[S, V^T], overwrite_a=.true.'
136+ call check(error,state%ok(),test//': '//state%print())
137+ if (allocated(error)) return
138+ call check(error, all(abs(s-s_sol)<=tol), test//': S')
139+ if (allocated(error)) return
140+ call check(error, all(abs(vt-vt_sol)<=tol) .or. all(abs(vt+vt_sol)<=tol), test//': V^T')
141+ if (allocated(error)) return
142+
120143 !> [U, S, V^T].
121144 A = A_mat
122145 call svd(A,s,u,vt,err=state)
123- error = state%error() .or. &
124- .not. all(abs(s-s_sol)<=tol) .or. &
125- .not.(all(abs(u-u_sol)<=tol) .or. all(abs(u+u_sol)<=tol)) .or. &
126- .not.(all(abs(vt-vt_sol)<=tol) .or. all(abs(vt+vt_sol)<=tol))
127- if (error) return
146+
147+ test = '[U, S, V^T]'
148+ call check(error,state%ok(),test//': '//state%print())
149+ if (allocated(error)) return
150+ call check(error, all(abs(u-u_sol)<=tol) .or. all(abs(u+u_sol)<=tol), test//': U')
151+ if (allocated(error)) return
152+ call check(error, all(abs(s-s_sol)<=tol), test//': S')
153+ if (allocated(error)) return
154+ call check(error, all(abs(vt-vt_sol)<=tol) .or. all(abs(vt+vt_sol)<=tol), test//': V^T')
155+ if (allocated(error)) return
128156
129157 !> [U, S, V^T]. Partial storage -> compare until k=2 columns of U rows of V^T
130158 A = A_mat
131159 u = 0
132160 vt = 0
133161 call svd(A,s,u,vt,full_matrices=.false.,err=state)
134- error = state%error() &
135- .or. .not. all(abs(s-s_sol)<=tol) &
136- .or. .not.(all(abs( u(:,:2)- u_sol(:,:2))<=tol) .or. all(abs( u(:,:2)+ u_sol(:,:2))<=tol)) &
137- .or. .not.(all(abs(vt(:2,:)-vt_sol(:2,:))<=tol) .or. all(abs(vt(:2,:)+vt_sol(:2,:))<=tol))
138-
139- if (error) return
162+
163+ test = '[U, S, V^T], partial storage'
164+ call check(error,state%ok(),test//': '//state%print())
165+ if (allocated(error)) return
166+ call check(error, all(abs(u(:,:2)-u_sol(:,:2))<=tol) .or. all(abs(u(:,:2)+u_sol(:,:2))<=tol), test//': U(:,:2)')
167+ if (allocated(error)) return
168+ call check(error, all(abs(s-s_sol)<=tol), test//': S')
169+ if (allocated(error)) return
170+ call check(error, all(abs(vt(:2,:)-vt_sol(:2,:))<=tol) .or. all(abs(vt(:2,:)+vt_sol(:2,:))<=tol), test//': V^T(:2,:)')
171+ if (allocated(error)) return
140172
141173 end subroutine test_svd_${ri}$
142174
@@ -147,7 +179,7 @@ module test_linalg_svd
147179 #:for ck,ct,ci in CMPLX_KINDS_TYPES
148180 #:if ck!="xdp"
149181 subroutine test_complex_svd_${ci}$(error)
150- logical, intent(out) :: error
182+ type(error_type), allocatable, intent(out) :: error
151183
152184 !> Reference solution
153185 real(${ck}$), parameter :: tol = sqrt(epsilon(0.0_${ck}$))
@@ -165,6 +197,7 @@ module test_linalg_svd
165197 ${ct}$, parameter :: vt_sol(2,2) = reshape([cone,czero,czero,cone],[2,2])
166198
167199 !> Local variables
200+ character(:), allocatable :: test
168201 type(linalg_state_type) :: state
169202 ${ct}$ :: A(2,2),u(2,2),vt(2,2)
170203 real(${ck}$) :: s(2)
@@ -174,28 +207,63 @@ module test_linalg_svd
174207
175208 !> Simple subroutine version
176209 call svd(A,s,err=state)
177- error = state%error() .or. .not. all(abs(s-s_sol)<=tol)
178- if (error) return
179-
210+
211+ test = '[S], complex subroutine'
212+ call check(error,state%ok(),test//': '//state%print())
213+ if (allocated(error)) return
214+ call check(error, all(abs(s-s_sol)<=tol), test//': S')
215+ if (allocated(error)) return
216+
180217 !> Function interface
181218 s = svdvals(A,err=state)
182- error = state%error() .or. .not. all(abs(s-s_sol)<=tol)
183- if (error) return
219+
220+ test = 'svdvals, complex function'
221+ call check(error,state%ok(),test//': '//state%print())
222+ if (allocated(error)) return
223+ call check(error, all(abs(s-s_sol)<=tol), test//': S')
224+ if (allocated(error)) return
184225
185226 !> [S, U, V^T]
186227 A = A_mat
187228 call svd(A,s,u,vt,overwrite_a=.true.,err=state)
188- error = state%error() .or. &
189- .not. all(abs(s-s_sol)<=tol) .or. &
190- .not. all(abs(matmul(u,matmul(diag(s),vt)) - A_mat)<=tol)
191- if (error) return
229+
230+ test = '[S, U, V^T], complex'
231+ call check(error,state%ok(),test//': '//state%print())
232+ if (allocated(error)) return
233+ call check(error, all(abs(s-s_sol)<=tol), test//': S')
234+ if (allocated(error)) return
235+ call check(error, all(abs(matmul(u,matmul(diag(s),vt))-A_mat)<=tol), test//': U*S*V^T')
236+ if (allocated(error)) return
192237
193238 end subroutine test_complex_svd_${ci}$
194239
195240 #:endif
196241 #:endfor
197242
198-
199243end module test_linalg_svd
200244
201-
245+ program test_lstsq
246+ use, intrinsic :: iso_fortran_env, only : error_unit
247+ use testdrive, only : run_testsuite, new_testsuite, testsuite_type
248+ use test_linalg_svd, only : test_svd
249+ implicit none
250+ integer :: stat, is
251+ type(testsuite_type), allocatable :: testsuites(:)
252+ character(len=*), parameter :: fmt = '("#", *(1x, a))'
253+
254+ stat = 0
255+
256+ testsuites = [ &
257+ new_testsuite("linalg_svd", test_svd) &
258+ ]
259+
260+ do is = 1, size(testsuites)
261+ write(error_unit, fmt) "Testing:", testsuites(is)%name
262+ call run_testsuite(testsuites(is)%collect, error_unit, stat)
263+ end do
264+
265+ if (stat > 0) then
266+ write(error_unit, '(i0, 1x, a)') stat, "test(s) failed!"
267+ error stop
268+ end if
269+ end program test_lstsq
0 commit comments