Skip to content

Commit c579f74

Browse files
authored
Merge pull request #150 from wolfv/fix_svd_horizontal_vertical
fix svd for horizontal and vertical
2 parents 0ab6257 + 327b385 commit c579f74

File tree

4 files changed

+100
-22
lines changed

4 files changed

+100
-22
lines changed

include/xtensor-blas/xblas_utils.hpp

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,9 +185,26 @@ namespace xt
185185
std::false_type is_xfunction_impl(...);
186186
}
187187

188-
template<typename T>
188+
template<class T>
189189
constexpr bool is_xfunction(T&& t) {
190190
return decltype(detail::is_xfunction_impl(t))::value;
191191
}
192+
193+
/***********************************
194+
* assert_nd_square implementation *
195+
***********************************/
196+
197+
template <class T>
198+
#if !defined(_MSC_VER) || _MSC_VER >= 1910
199+
constexpr
200+
#endif
201+
void assert_nd_square(const xexpression<T>& t)
202+
{
203+
auto& dt = t.derived_cast();
204+
if (dt.shape()[dt.dimension() - 1] != dt.shape()[dt.dimension() - 2])
205+
{
206+
throw std::runtime_error("Last 2 dimensions of the array must be square.");
207+
}
208+
}
192209
}
193210
#endif

include/xtensor-blas/xlapack.hpp

Lines changed: 29 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -89,12 +89,15 @@ namespace lapack
8989
n = static_cast<blas_index_t>(A.shape()[1]);
9090
}
9191

92+
blas_index_t m = static_cast<blas_index_t>(A.shape()[0]);
93+
blas_index_t a_stride = std::max(blas_index_t(1), m);
94+
9295
int info = cxxlapack::orgqr<blas_index_t>(
93-
static_cast<blas_index_t>(A.shape()[0]),
96+
m,
9497
n,
9598
static_cast<blas_index_t>(tau.size()),
9699
A.data(),
97-
stride_back(A),
100+
a_stride,
98101
tau.data(),
99102
work.data(),
100103
static_cast<blas_index_t>(-1)
@@ -108,11 +111,11 @@ namespace lapack
108111
work.resize(static_cast<std::size_t>(work[0]));
109112

110113
info = cxxlapack::orgqr<blas_index_t>(
111-
static_cast<blas_index_t>(A.shape()[0]),
114+
m,
112115
n,
113116
static_cast<blas_index_t>(tau.size()),
114117
A.data(),
115-
stride_back(A),
118+
a_stride,
116119
tau.data(),
117120
work.data(),
118121
static_cast<blas_index_t>(work.size())
@@ -133,12 +136,15 @@ namespace lapack
133136
n = static_cast<blas_index_t>(A.shape()[1]);
134137
}
135138

139+
blas_index_t m = static_cast<blas_index_t>(A.shape()[0]);
140+
blas_index_t a_stride = std::max(blas_index_t(1), m);
141+
136142
int info = cxxlapack::ungqr<blas_index_t>(
137-
static_cast<blas_index_t>(A.shape()[0]),
143+
m,
138144
n,
139145
static_cast<blas_index_t>(tau.size()),
140146
A.data(),
141-
stride_back(A),
147+
a_stride,
142148
tau.data(),
143149
work.data(),
144150
static_cast<blas_index_t>(-1)
@@ -152,11 +158,11 @@ namespace lapack
152158
work.resize(static_cast<std::size_t>(std::real(work[0])));
153159

154160
info = cxxlapack::ungqr<blas_index_t>(
155-
static_cast<blas_index_t>(A.shape()[0]),
161+
m,
156162
n,
157163
static_cast<blas_index_t>(tau.size()),
158164
A.data(),
159-
stride_back(A),
165+
a_stride,
160166
tau.data(),
161167
work.data(),
162168
static_cast<blas_index_t>(work.size())
@@ -174,12 +180,14 @@ namespace lapack
174180
XTENSOR_ASSERT(A.layout() == layout_type::column_major);
175181

176182
uvector<value_type> work(1);
183+
blas_index_t m = static_cast<blas_index_t>(A.shape()[0]);
184+
blas_index_t a_stride = std::max(blas_index_t(1), m);
177185

178186
int info = cxxlapack::geqrf<blas_index_t>(
179-
static_cast<blas_index_t>(A.shape()[0]),
187+
m,
180188
static_cast<blas_index_t>(A.shape()[1]),
181189
A.data(),
182-
stride_back(A),
190+
a_stride,
183191
tau.data(),
184192
work.data(),
185193
static_cast<blas_index_t>(-1)
@@ -193,10 +201,10 @@ namespace lapack
193201
work.resize(static_cast<std::size_t>(std::real(work[0])));
194202

195203
info = cxxlapack::geqrf<blas_index_t>(
196-
static_cast<blas_index_t>(A.shape()[0]),
204+
m,
197205
static_cast<blas_index_t>(A.shape()[1]),
198206
A.data(),
199-
stride_back(A),
207+
a_stride,
200208
tau.data(),
201209
work.data(),
202210
static_cast<blas_index_t>(work.size())
@@ -241,7 +249,9 @@ namespace lapack
241249
return m >= n ? std::make_pair(1, stride_back(vt)) :
242250
std::make_pair(stride_back(u), 1);
243251
}
244-
return std::make_pair(stride_back(u), stride_back(vt));
252+
253+
return std::make_pair(std::max(blas_index_t(u.shape()[0]), 1),
254+
std::max(blas_index_t(vt.shape()[0]), 1));
245255
}
246256
}
247257

@@ -269,13 +279,14 @@ namespace lapack
269279
std::tie(u_stride, vt_stride) = detail::init_u_vt(u, vt, jobz, m, n);
270280

271281
uvector<blas_index_t> iwork(8 * std::min(m, n));
282+
blas_index_t a_stride = static_cast<blas_index_t>(std::max(std::size_t(1), m));
272283

273284
int info = cxxlapack::gesdd<blas_index_t>(
274285
jobz,
275286
static_cast<blas_index_t>(A.shape()[0]),
276287
static_cast<blas_index_t>(A.shape()[1]),
277288
A.data(),
278-
stride_back(A),
289+
a_stride,
279290
s.data(),
280291
u.data(),
281292
u_stride,
@@ -292,13 +303,12 @@ namespace lapack
292303
}
293304

294305
work.resize(static_cast<std::size_t>(work[0]));
295-
296306
info = cxxlapack::gesdd<blas_index_t>(
297307
jobz,
298308
static_cast<blas_index_t>(A.shape()[0]),
299309
static_cast<blas_index_t>(A.shape()[1]),
300310
A.data(),
301-
stride_back(A),
311+
a_stride,
302312
s.data(),
303313
u.data(),
304314
u_stride,
@@ -355,13 +365,14 @@ namespace lapack
355365

356366
blas_index_t u_stride, vt_stride;
357367
std::tie(u_stride, vt_stride) = detail::init_u_vt(u, vt, jobz, m, n);
368+
blas_index_t a_stride = static_cast<blas_index_t>(std::max(std::size_t(1), m));
358369

359370
int info = cxxlapack::gesdd<blas_index_t>(
360371
jobz,
361372
static_cast<blas_index_t>(A.shape()[0]),
362373
static_cast<blas_index_t>(A.shape()[1]),
363374
A.data(),
364-
stride_back(A),
375+
a_stride,
365376
s.data(),
366377
u.data(),
367378
u_stride,
@@ -384,7 +395,7 @@ namespace lapack
384395
static_cast<blas_index_t>(A.shape()[0]),
385396
static_cast<blas_index_t>(A.shape()[1]),
386397
A.data(),
387-
stride_back(A),
398+
a_stride,
388399
s.data(),
389400
u.data(),
390401
u_stride,

include/xtensor-blas/xlinalg.hpp

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,7 @@ namespace linalg
226226
template <class E1, class E2>
227227
auto solve(const xexpression<E1>& A, const xexpression<E2>& b)
228228
{
229+
assert_nd_square(A);
229230
auto dA = copy_to_layout<layout_type::column_major>(A.derived_cast());
230231
auto db = copy_to_layout<layout_type::column_major>(b.derived_cast());
231232

@@ -248,6 +249,7 @@ namespace linalg
248249
template <class E1>
249250
auto inv(const xexpression<E1>& A)
250251
{
252+
assert_nd_square(A);
251253
auto dA = copy_to_layout<layout_type::column_major>(A.derived_cast());
252254

253255
uvector<blas_index_t> piv(std::min(dA.shape()[0], dA.shape()[1]));
@@ -299,6 +301,7 @@ namespace linalg
299301
using underlying_type = typename E::value_type;
300302
using value_type = typename E::value_type;
301303

304+
assert_nd_square(A);
302305
auto M = copy_to_layout<layout_type::column_major>(A.derived_cast());
303306

304307
std::size_t N = M.shape()[0];
@@ -348,6 +351,7 @@ namespace linalg
348351
{
349352
using value_type = typename E::value_type;
350353

354+
assert_nd_square(A);
351355
auto M = copy_to_layout<layout_type::column_major>(A.derived_cast());
352356

353357
std::size_t N = M.shape()[0];
@@ -380,6 +384,7 @@ namespace linalg
380384
{
381385
using value_type = typename E::value_type;
382386

387+
assert_nd_square(A);
383388
auto M = copy_to_layout<layout_type::column_major>(A.derived_cast());
384389

385390
std::size_t N = M.shape()[0];
@@ -401,6 +406,7 @@ namespace linalg
401406
using value_type = typename E::value_type;
402407
using underlying_value_type = typename value_type::value_type;
403408

409+
assert_nd_square(A);
404410
auto M = copy_to_layout<layout_type::column_major>(A.derived_cast());
405411

406412
std::size_t N = M.shape()[0];
@@ -424,10 +430,11 @@ namespace linalg
424430
* @return xtensor containing the eigenvalues.
425431
*/
426432
template <class E, std::enable_if_t<!xtl::is_complex<typename E::value_type>::value>* = nullptr>
427-
auto eigh(const xexpression<E>& A, const xexpression<E>& B,const char UPLO = 'L')
433+
auto eigh(const xexpression<E>& A, const xexpression<E>& B, const char UPLO = 'L')
428434
{
429435
using value_type = typename E::value_type;
430436

437+
assert_nd_square(A);
431438
auto M1 = copy_to_layout<layout_type::column_major>(A.derived_cast());
432439
auto M2 = copy_to_layout<layout_type::column_major>(B.derived_cast());
433440

@@ -478,6 +485,7 @@ namespace linalg
478485
{
479486
using value_type = typename E::value_type;
480487

488+
assert_nd_square(A);
481489
auto M = copy_to_layout<layout_type::column_major>(A.derived_cast());
482490

483491
std::size_t N = M.shape()[0];
@@ -511,6 +519,7 @@ namespace linalg
511519
{
512520
using value_type = typename E::value_type;
513521

522+
assert_nd_square(A);
514523
auto M = copy_to_layout<layout_type::column_major>(A.derived_cast());
515524

516525
std::size_t N = M.shape()[0];
@@ -545,6 +554,7 @@ namespace linalg
545554
{
546555
using value_type = typename E::value_type;
547556

557+
assert_nd_square(A);
548558
auto M = copy_to_layout<layout_type::column_major>(A.derived_cast());
549559

550560
std::size_t N = M.shape()[0];
@@ -566,6 +576,7 @@ namespace linalg
566576
using value_type = typename E::value_type;
567577
using underlying_value_type = typename value_type::value_type;
568578

579+
assert_nd_square(A);
569580
auto M = copy_to_layout<layout_type::column_major>(A.derived_cast());
570581

571582
std::size_t N = M.shape()[0];
@@ -989,8 +1000,9 @@ namespace linalg
9891000
auto det(const xexpression<T>& A)
9901001
{
9911002
using value_type = typename T::value_type;
992-
xtensor<value_type, 2, layout_type::column_major> LU = A.derived_cast();
1003+
assert_nd_square(A);
9931004

1005+
xtensor<value_type, 2, layout_type::column_major> LU = A.derived_cast();
9941006
uvector<blas_index_t> piv(std::min(LU.shape()[0], LU.shape()[1]));
9951007

9961008
lapack::getrf(LU, piv);
@@ -1025,6 +1037,7 @@ namespace linalg
10251037
auto slogdet(const xexpression<T>& A)
10261038
{
10271039
using value_type = typename T::value_type;
1040+
assert_nd_square(A);
10281041

10291042
xtensor<value_type, 2, layout_type::column_major> LU = A.derived_cast();
10301043
uvector<blas_index_t> piv(std::min(LU.shape()[0], LU.shape()[1]));
@@ -1059,6 +1072,8 @@ namespace linalg
10591072
auto slogdet(const xexpression<T>& A)
10601073
{
10611074
using value_type = typename T::value_type;
1075+
assert_nd_square(A);
1076+
10621077
xtensor<value_type, 2, layout_type::column_major> LU = A.derived_cast();
10631078
uvector<blas_index_t> piv(std::min(LU.shape()[0], LU.shape()[1]));
10641079

@@ -1214,6 +1229,7 @@ namespace linalg
12141229
template <class T>
12151230
auto cholesky(const xexpression<T>& A)
12161231
{
1232+
assert_nd_square(A);
12171233
auto M = copy_to_layout<layout_type::column_major>(A.derived_cast());
12181234

12191235
int info = lapack::potr(M, 'L');

test/test_linalg.cpp

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,19 @@ namespace xt
134134
EXPECT_TRUE(allclose(std::get<2>(res), expected_2));
135135
}
136136

137+
TEST(xlinalg, svd_horizontal_vertical)
138+
{
139+
xarray<double> a = xt::ones<double>({3, 1});
140+
xarray<double> b = xt::ones<double>({1, 3});
141+
xarray<double> u, s, vt;
142+
143+
std::tie(u, s, vt) = linalg::svd(a, false);
144+
EXPECT_TRUE(allclose(a, xt::linalg::dot(u * s, vt)));
145+
146+
std::tie(u, s, vt) = linalg::svd(b, false);
147+
EXPECT_TRUE(allclose(b, xt::linalg::dot(u * s, vt)));
148+
}
149+
137150
TEST(xlinalg, matrix_rank)
138151
{
139152
xarray<double> eall = eye<double>(4);
@@ -590,8 +603,29 @@ namespace xt
590603

591604
auto res = xt::linalg::dot(A1, A2);
592605
EXPECT_EQ(res(), 94);
606+
}
593607

594-
608+
TEST(xlinalg, asserts)
609+
{
610+
EXPECT_THROW(xt::linalg::eigh(xt::ones<double>({3, 1})), std::runtime_error);
611+
EXPECT_THROW(xt::linalg::eig(xt::ones<double>({3, 1})), std::runtime_error);
612+
EXPECT_THROW(xt::linalg::solve(xt::ones<double>({3, 1}), xt::ones<double>({3, 1})), std::runtime_error);
613+
EXPECT_THROW(xt::linalg::inv(xt::ones<double>({3, 1})), std::runtime_error);
614+
EXPECT_THROW(xt::linalg::eigvals(xt::ones<double>({3, 1})), std::runtime_error);
615+
EXPECT_THROW(xt::linalg::eigvalsh(xt::ones<double>({3, 1})), std::runtime_error);
616+
EXPECT_THROW(xt::linalg::det(xt::ones<double>({3, 1})), std::runtime_error);
617+
EXPECT_THROW(xt::linalg::slogdet(xt::ones<double>({3, 1})), std::runtime_error);
618+
EXPECT_THROW(xt::linalg::cholesky(xt::ones<double>({3, 1})), std::runtime_error);
619+
620+
EXPECT_THROW(xt::linalg::eigh(xt::ones<std::complex<double>>({3, 1})), std::runtime_error);
621+
EXPECT_THROW(xt::linalg::eig(xt::ones<std::complex<double>>({3, 1})), std::runtime_error);
622+
EXPECT_THROW(xt::linalg::solve(xt::ones<std::complex<double>>({3, 1}), xt::ones<std::complex<double>>({3, 1})), std::runtime_error);
623+
EXPECT_THROW(xt::linalg::inv(xt::ones<std::complex<double>>({3, 1})), std::runtime_error);
624+
EXPECT_THROW(xt::linalg::eigvals(xt::ones<std::complex<double>>({3, 1})), std::runtime_error);
625+
EXPECT_THROW(xt::linalg::eigvalsh(xt::ones<std::complex<double>>({3, 1})), std::runtime_error);
626+
EXPECT_THROW(xt::linalg::det(xt::ones<std::complex<double>>({3, 1})), std::runtime_error);
627+
EXPECT_THROW(xt::linalg::slogdet(xt::ones<std::complex<double>>({3, 1})), std::runtime_error);
628+
EXPECT_THROW(xt::linalg::cholesky(xt::ones<std::complex<double>>({3, 1})), std::runtime_error);
595629
}
596630

597631
}

0 commit comments

Comments
 (0)