@@ -105,7 +105,7 @@ static ffi::Error SvdOnlyVtImpl(
105105
106106 MachineType* u_data;
107107 MachineType* vt_data;
108- if (mode == UVtMode::computeOnlyU && x_rows < x_cols) {
108+ if (( mode == UVtMode::computeOnlyU || mode == UVtMode::computePartialUandVt) && x_rows < x_cols) {
109109 u_data = u_or_vt_data;
110110 vt_data = nullptr ;
111111 } else {
@@ -122,7 +122,7 @@ static ffi::Error SvdOnlyVtImpl(
122122 const char jobz = ' O' ;
123123 lapack_int ldu;
124124 lapack_int ldvt;
125- if (mode == UVtMode::computeOnlyU && x_rows < x_cols) {
125+ if (( mode == UVtMode::computeOnlyU || mode == UVtMode::computePartialUandVt) && x_rows < x_cols) {
126126 ldu = x_rows_lapack;
127127 ldvt = 1 ;
128128 } else {
@@ -193,6 +193,7 @@ static ffi::Error SvdOnlyVtQRImpl(
193193 ffi::Buffer<dtype> x,
194194 ffi::ResultBuffer<dtype> x_out,
195195 ffi::ResultBuffer<ffi::ToReal(dtype)> s,
196+ ffi::ResultBuffer<dtype> u_or_vt,
196197 ffi::ResultBuffer<ffi::DataType::S32> info,
197198 UVtMode mode) {
198199
@@ -275,9 +276,12 @@ static ffi::Error SvdOnlyVtQRImpl(
275276
276277 auto * x_out_data = x_out->typed_data ();
277278 auto * s_data = s->typed_data ();
278- // auto* vt_data = vt ->typed_data();
279+ auto * u_or_vt_data = u_or_vt ->typed_data ();
279280 auto * info_data = info->typed_data ();
280281
282+ MachineType* u_data;
283+ MachineType* vt_data;
284+
281285 if (x.typed_data () != x_out_data) {
282286 std::copy_n (x.typed_data (), x.element_count (), x_out_data);
283287 }
@@ -287,18 +291,38 @@ static ffi::Error SvdOnlyVtQRImpl(
287291
288292 char jobu;
289293 char jobvt;
290- const lapack_int ldu = 1 ;
291- const lapack_int ldvt = 1 ;
294+ lapack_int ldu;
295+ lapack_int ldvt;
292296 if (mode == UVtMode::computeOnlyU) {
293297 jobu = ' O' ;
294298 jobvt = ' N' ;
295- // ldu = 1;
296- // ldvt = 1;
297- } else {
299+ ldu = 1 ;
300+ ldvt = 1 ;
301+ u_data = nullptr ;
302+ vt_data = nullptr ;
303+ } else if (mode == UVtMode::computeOnlyVt) {
298304 jobu = ' N' ;
299305 jobvt = ' O' ;
300- // ldu = 1;
301- // ldvt = 1;
306+ ldu = 1 ;
307+ ldvt = 1 ;
308+ u_data = nullptr ;
309+ vt_data = nullptr ;
310+ } else {
311+ if (x_rows >= x_cols) {
312+ jobu = ' O' ;
313+ jobvt = ' S' ;
314+ ldu = 1 ;
315+ ldvt = x_cols_lapack;
316+ u_data = nullptr ;
317+ vt_data = u_or_vt_data;
318+ } else {
319+ jobu = ' S' ;
320+ jobvt = ' O' ;
321+ ldu = x_rows_lapack;
322+ ldvt = 1 ;
323+ u_data = u_or_vt_data;
324+ vt_data = nullptr ;
325+ }
302326 }
303327
304328 if constexpr (ffi::IsComplexType<dtype>()) {
@@ -337,14 +361,14 @@ static ffi::Error SvdOnlyVtQRImpl(
337361
338362 if constexpr (ffi::IsComplexType<dtype>()) {
339363 fn (&jobu, &jobvt, &x_rows_lapack, &x_cols_lapack, x_out_data,
340- &x_rows_lapack, s_data, nullptr ,
341- &ldu, nullptr , &ldvt, work.get (),
364+ &x_rows_lapack, s_data, u_data ,
365+ &ldu, vt_data , &ldvt, work.get (),
342366 &lwork, rwork.get (), info_data
343367 );
344368 } else {
345369 fn (&jobu, &jobvt, &x_rows_lapack, &x_cols_lapack, x_out_data,
346- &x_rows_lapack, s_data, nullptr ,
347- &ldu, nullptr , &ldvt,
370+ &x_rows_lapack, s_data, u_data ,
371+ &ldu, vt_data , &ldvt,
348372 work.get (), &lwork, info_data
349373 );
350374 }
@@ -363,7 +387,7 @@ static ffi::Error SvdOnlyVtQRImpl(
363387 .Arg<ffi::Buffer<dtype>>(/* x*/ ) \
364388 .Ret<ffi::Buffer<dtype>>(/* x_out*/ ) \
365389 .Ret<ffi::Buffer<dtype>>(/* s*/ ) \
366- .Ret<ffi::Buffer<dtype>>(/* vt */ ) \
390+ .Ret<ffi::Buffer<dtype>>(/* u_or_vt */ ) \
367391 .Ret<ffi::Buffer<ffi::DataType::S32>>(/* info*/ ) \
368392 .Attr<UVtMode>(" mode" ))
369393
@@ -374,7 +398,7 @@ static ffi::Error SvdOnlyVtQRImpl(
374398 .Arg<ffi::Buffer<dtype>>(/* x*/ ) \
375399 .Ret<ffi::Buffer<dtype>>(/* x_out*/ ) \
376400 .Ret<ffi::Buffer<ffi::ToReal(dtype)>>(/* s*/ ) \
377- .Ret<ffi::Buffer<dtype>>(/* vt */ ) \
401+ .Ret<ffi::Buffer<dtype>>(/* u_or_vt */ ) \
378402 .Ret<ffi::Buffer<ffi::DataType::S32>>(/* info*/ ) \
379403 .Attr<UVtMode>(" mode" ))
380404
@@ -390,6 +414,7 @@ DEFINE_COMPLEX_SVD_ONLY_VT(svd_only_vt_c128, ffi::DataType::C128);
390414 .Arg<ffi::Buffer<dtype>>(/* x*/ ) \
391415 .Ret<ffi::Buffer<dtype>>(/* x_out*/ ) \
392416 .Ret<ffi::Buffer<dtype>>(/* s*/ ) \
417+ .Ret<ffi::Buffer<dtype>>(/* u_or_vt*/ ) \
393418 .Ret<ffi::Buffer<ffi::DataType::S32>>(/* info*/ ) \
394419 .Attr<UVtMode>(" mode" ))
395420
@@ -400,6 +425,7 @@ DEFINE_COMPLEX_SVD_ONLY_VT(svd_only_vt_c128, ffi::DataType::C128);
400425 .Arg<ffi::Buffer<dtype>>(/* x*/ ) \
401426 .Ret<ffi::Buffer<dtype>>(/* x_out*/ ) \
402427 .Ret<ffi::Buffer<ffi::ToReal(dtype)>>(/* s*/ ) \
428+ .Ret<ffi::Buffer<dtype>>(/* u_or_vt*/ ) \
403429 .Ret<ffi::Buffer<ffi::DataType::S32>>(/* info*/ ) \
404430 .Attr<UVtMode>(" mode" ))
405431
0 commit comments