1+ #include " svd_ffi.h"
12#include " nanobind/nanobind.h"
23#include " xla/ffi/api/ffi.h"
34
5+ namespace nb = nanobind;
6+
47using lapack_int = int ;
58
6- namespace nb = nanobind;
7- using namespace ::xla;
9+ namespace ffi = xla::ffi;
810
9- inline constexpr auto LapackIntDtype = ffi::DataType::S32;
10- static_assert (std::is_same_v<::xla::ffi::NativeType<LapackIntDtype>, lapack_int>);
11+ XLA_FFI_REGISTER_ENUM_ATTR_DECODING (UVtMode);
1112
1213template <ffi::DataType dtype>
1314static ffi::Error SvdOnlyVtImpl (
1415 ffi::Buffer<dtype> x,
1516 ffi::ResultBuffer<dtype> x_out,
1617 ffi::ResultBuffer<ffi::ToReal(dtype)> s,
17- ffi::ResultBuffer<dtype> vt,
18- ffi::ResultBuffer<LapackIntDtype> info) {
18+ ffi::ResultBuffer<dtype> u_or_vt,
19+ ffi::ResultBuffer<ffi::DataType::S32> info,
20+ UVtMode mode) {
1921
2022 using MachineType = ffi::NativeType<dtype>;
2123 using RealType = ffi::NativeType<ffi::ToReal (dtype)>;
@@ -76,48 +78,68 @@ static ffi::Error SvdOnlyVtImpl(
7678
7779 const auto lapack_int_max = std::numeric_limits<lapack_int>::max ();
7880
79- ffi::Span<const int64_t > dims = x.dimensions ();
81+ const ffi::Span<const int64_t > dims = x.dimensions ();
8082 if (dims.size () != 2 ) {
8183 return ffi::Error (ffi::ErrorCode::kInvalidArgument , " Only 2d arrays supported as input." );
8284 }
83- int64_t x_rows = dims.front ();
84- int64_t x_cols = dims.back ();
85+ const int64_t x_rows = dims.front ();
86+ const int64_t x_cols = dims.back ();
8587
86- if (x_rows < x_cols) [[unlikely]] {
88+ if (mode == UVtMode::computeOnlyU && x_rows > x_cols) [[unlikely]] {
89+ return ffi::Error (ffi::ErrorCode::kInvalidArgument , " Only matrices with M <= N supported." );
90+ } else if (mode == UVtMode::computeOnlyVt && x_rows < x_cols) [[unlikely]] {
8791 return ffi::Error (ffi::ErrorCode::kInvalidArgument , " Only matrices with M >= N supported." );
8892 }
8993
9094 if (x_rows > lapack_int_max || x_cols > lapack_int_max) [[unlikely]] {
9195 return ffi::Error (ffi::ErrorCode::kOutOfRange , " Dimension of input out of range for lapack integer." );
9296 }
9397
94- lapack_int x_rows_lapack = static_cast <lapack_int>(x_rows);
95- lapack_int x_cols_lapack = static_cast <lapack_int>(x_cols);
98+ const lapack_int x_rows_lapack = static_cast <lapack_int>(x_rows);
99+ const lapack_int x_cols_lapack = static_cast <lapack_int>(x_cols);
96100
97101 auto * x_out_data = x_out->typed_data ();
98102 auto * s_data = s->typed_data ();
99- auto * vt_data = vt ->typed_data ();
103+ auto * u_or_vt_data = u_or_vt ->typed_data ();
100104 auto * info_data = info->typed_data ();
101105
106+ MachineType* u_data;
107+ MachineType* vt_data;
108+ if (mode == UVtMode::computeOnlyU && x_rows < x_cols) {
109+ u_data = u_or_vt_data;
110+ vt_data = nullptr ;
111+ } else {
112+ u_data = nullptr ;
113+ vt_data = u_or_vt_data;
114+ }
115+
102116 if (x.typed_data () != x_out_data) {
103117 std::copy_n (x.typed_data (), x.element_count (), x_out_data);
104118 }
105119
106120 ffi::NativeType<dtype> work_size = {};
107121 lapack_int lwork = -1 ;
108- char jobz = ' O' ;
109- lapack_int ldu = 1 ;
122+ const char jobz = ' O' ;
123+ lapack_int ldu;
124+ lapack_int ldvt;
125+ if (mode == UVtMode::computeOnlyU && x_rows < x_cols) {
126+ ldu = x_rows_lapack;
127+ ldvt = 1 ;
128+ } else {
129+ ldu = 1 ;
130+ ldvt = x_cols_lapack;
131+ }
110132
111133 if constexpr (ffi::IsComplexType<dtype>()) {
112134 fn (&jobz, &x_rows_lapack, &x_cols_lapack, nullptr ,
113135 &x_rows_lapack, nullptr , nullptr ,
114- &ldu, nullptr , &x_cols_lapack , &work_size,
136+ &ldu, nullptr , &ldvt , &work_size,
115137 &lwork, nullptr , nullptr , info_data
116138 );
117139 } else {
118140 fn (&jobz, &x_rows_lapack, &x_cols_lapack, nullptr ,
119141 &x_rows_lapack, nullptr , nullptr ,
120- &ldu, nullptr , &x_cols_lapack ,
142+ &ldu, nullptr , &ldvt ,
121143 &work_size, &lwork, nullptr , info_data
122144 );
123145 }
@@ -147,14 +169,14 @@ static ffi::Error SvdOnlyVtImpl(
147169
148170 if constexpr (ffi::IsComplexType<dtype>()) {
149171 fn (&jobz, &x_rows_lapack, &x_cols_lapack, x_out_data,
150- &x_rows_lapack, s_data, nullptr ,
151- &ldu, vt_data, &x_cols_lapack , work.get (),
172+ &x_rows_lapack, s_data, u_data ,
173+ &ldu, vt_data, &ldvt , work.get (),
152174 &lwork, rwork.get (), iwork.get (), info_data
153175 );
154176 } else {
155177 fn (&jobz, &x_rows_lapack, &x_cols_lapack, x_out_data,
156- &x_rows_lapack, s_data, nullptr ,
157- &ldu, vt_data, &x_cols_lapack ,
178+ &x_rows_lapack, s_data, u_data ,
179+ &ldu, vt_data, &ldvt ,
158180 work.get (), &lwork, iwork.get (), info_data
159181 );
160182 }
@@ -171,7 +193,8 @@ static ffi::Error SvdOnlyVtQRImpl(
171193 ffi::Buffer<dtype> x,
172194 ffi::ResultBuffer<dtype> x_out,
173195 ffi::ResultBuffer<ffi::ToReal(dtype)> s,
174- ffi::ResultBuffer<LapackIntDtype> info) {
196+ ffi::ResultBuffer<ffi::DataType::S32> info,
197+ UVtMode mode) {
175198
176199 using MachineType = ffi::NativeType<dtype>;
177200 using RealType = ffi::NativeType<ffi::ToReal (dtype)>;
@@ -230,23 +253,25 @@ static ffi::Error SvdOnlyVtQRImpl(
230253
231254 const auto lapack_int_max = std::numeric_limits<lapack_int>::max ();
232255
233- ffi::Span<const int64_t > dims = x.dimensions ();
256+ const ffi::Span<const int64_t > dims = x.dimensions ();
234257 if (dims.size () != 2 ) {
235258 return ffi::Error (ffi::ErrorCode::kInvalidArgument , " Only 2d arrays supported as input." );
236259 }
237- int64_t x_rows = dims.front ();
238- int64_t x_cols = dims.back ();
260+ const int64_t x_rows = dims.front ();
261+ const int64_t x_cols = dims.back ();
239262
240- if (x_rows < x_cols) [[unlikely]] {
263+ if (mode == UVtMode::computeOnlyU && x_rows > x_cols) [[unlikely]] {
264+ return ffi::Error (ffi::ErrorCode::kInvalidArgument , " Only matrices with M <= N supported." );
265+ } else if (mode == UVtMode::computeOnlyVt && x_rows < x_cols) [[unlikely]] {
241266 return ffi::Error (ffi::ErrorCode::kInvalidArgument , " Only matrices with M >= N supported." );
242267 }
243268
244269 if (x_rows > lapack_int_max || x_cols > lapack_int_max) [[unlikely]] {
245270 return ffi::Error (ffi::ErrorCode::kOutOfRange , " Dimension of input out of range for lapack integer." );
246271 }
247272
248- lapack_int x_rows_lapack = static_cast <lapack_int>(x_rows);
249- lapack_int x_cols_lapack = static_cast <lapack_int>(x_cols);
273+ const lapack_int x_rows_lapack = static_cast <lapack_int>(x_rows);
274+ const lapack_int x_cols_lapack = static_cast <lapack_int>(x_cols);
250275
251276 auto * x_out_data = x_out->typed_data ();
252277 auto * s_data = s->typed_data ();
@@ -259,20 +284,33 @@ static ffi::Error SvdOnlyVtQRImpl(
259284
260285 ffi::NativeType<dtype> work_size = {};
261286 lapack_int lwork = -1 ;
262- char jobu = ' N' ;
263- char jobvt = ' O' ;
264- lapack_int ldu = 1 ;
287+
288+ char jobu;
289+ char jobvt;
290+ const lapack_int ldu = 1 ;
291+ const lapack_int ldvt = 1 ;
292+ if (mode == UVtMode::computeOnlyU) {
293+ jobu = ' O' ;
294+ jobvt = ' N' ;
295+ // ldu = 1;
296+ // ldvt = 1;
297+ } else {
298+ jobu = ' N' ;
299+ jobvt = ' O' ;
300+ // ldu = 1;
301+ // ldvt = 1;
302+ }
265303
266304 if constexpr (ffi::IsComplexType<dtype>()) {
267305 fn (&jobu, &jobvt, &x_rows_lapack, &x_cols_lapack, nullptr ,
268306 &x_rows_lapack, nullptr , nullptr ,
269- &ldu, nullptr , &x_cols_lapack , &work_size,
307+ &ldu, nullptr , &ldvt , &work_size,
270308 &lwork, nullptr , info_data
271309 );
272310 } else {
273311 fn (&jobu, &jobvt, &x_rows_lapack, &x_cols_lapack, nullptr ,
274312 &x_rows_lapack, nullptr , nullptr ,
275- &ldu, nullptr , &x_cols_lapack ,
313+ &ldu, nullptr , &ldvt ,
276314 &work_size, &lwork, info_data
277315 );
278316 }
@@ -300,13 +338,13 @@ static ffi::Error SvdOnlyVtQRImpl(
300338 if constexpr (ffi::IsComplexType<dtype>()) {
301339 fn (&jobu, &jobvt, &x_rows_lapack, &x_cols_lapack, x_out_data,
302340 &x_rows_lapack, s_data, nullptr ,
303- &ldu, nullptr , &x_cols_lapack , work.get (),
341+ &ldu, nullptr , &ldvt , work.get (),
304342 &lwork, rwork.get (), info_data
305343 );
306344 } else {
307345 fn (&jobu, &jobvt, &x_rows_lapack, &x_cols_lapack, x_out_data,
308346 &x_rows_lapack, s_data, nullptr ,
309- &ldu, nullptr , &x_cols_lapack ,
347+ &ldu, nullptr , &ldvt ,
310348 work.get (), &lwork, info_data
311349 );
312350 }
@@ -318,56 +356,60 @@ static ffi::Error SvdOnlyVtQRImpl(
318356 return ffi::Error::Success ();
319357}
320358
321- #define DEFINE_REAL_SVD_ONLY_VT (fname, dtype ) \
322- XLA_FFI_DEFINE_HANDLER_SYMBOL ( \
323- fname, SvdOnlyVtImpl<dtype>, \
324- ffi::Ffi::Bind () \
325- .Arg<ffi::Buffer<dtype>>(/* x*/ ) \
326- .Ret<ffi::Buffer<dtype>>(/* x_out*/ ) \
327- .Ret<ffi::Buffer<dtype>>(/* s*/ ) \
328- .Ret<ffi::Buffer<dtype>>(/* vt*/ ) \
329- .Ret<ffi::Buffer<LapackIntDtype>>(/* info*/ ))
330-
331- #define DEFINE_COMPLEX_SVD_ONLY_VT (fname, dtype ) \
332- XLA_FFI_DEFINE_HANDLER_SYMBOL ( \
333- fname, SvdOnlyVtImpl<dtype>, \
334- ffi::Ffi::Bind () \
335- .Arg<ffi::Buffer<dtype>>(/* x*/ ) \
336- .Ret<ffi::Buffer<dtype>>(/* x_out*/ ) \
337- .Ret<ffi::Buffer<ffi::ToReal(dtype)>>(/* s*/ ) \
338- .Ret<ffi::Buffer<dtype>>(/* vt*/ ) \
339- .Ret<ffi::Buffer<LapackIntDtype>>(/* info*/ ))
359+ #define DEFINE_REAL_SVD_ONLY_VT (fname, dtype ) \
360+ XLA_FFI_DEFINE_HANDLER_SYMBOL ( \
361+ fname, SvdOnlyVtImpl<dtype>, \
362+ ffi::Ffi::Bind () \
363+ .Arg<ffi::Buffer<dtype>>(/* x*/ ) \
364+ .Ret<ffi::Buffer<dtype>>(/* x_out*/ ) \
365+ .Ret<ffi::Buffer<dtype>>(/* s*/ ) \
366+ .Ret<ffi::Buffer<dtype>>(/* vt*/ ) \
367+ .Ret<ffi::Buffer<ffi::DataType::S32>>(/* info*/ ) \
368+ .Attr<UVtMode>(" mode" ))
369+
370+ #define DEFINE_COMPLEX_SVD_ONLY_VT (fname, dtype ) \
371+ XLA_FFI_DEFINE_HANDLER_SYMBOL ( \
372+ fname, SvdOnlyVtImpl<dtype>, \
373+ ffi::Ffi::Bind () \
374+ .Arg<ffi::Buffer<dtype>>(/* x*/ ) \
375+ .Ret<ffi::Buffer<dtype>>(/* x_out*/ ) \
376+ .Ret<ffi::Buffer<ffi::ToReal(dtype)>>(/* s*/ ) \
377+ .Ret<ffi::Buffer<dtype>>(/* vt*/ ) \
378+ .Ret<ffi::Buffer<ffi::DataType::S32>>(/* info*/ ) \
379+ .Attr<UVtMode>(" mode" ))
340380
341381DEFINE_REAL_SVD_ONLY_VT(svd_only_vt_f32, ffi::DataType::F32);
342382DEFINE_REAL_SVD_ONLY_VT (svd_only_vt_f64, ffi::DataType::F64);
343383DEFINE_COMPLEX_SVD_ONLY_VT (svd_only_vt_c64, ffi::DataType::C64);
344384DEFINE_COMPLEX_SVD_ONLY_VT (svd_only_vt_c128, ffi::DataType::C128);
345385
346- #define DEFINE_REAL_SVD_ONLY_VT_QR (fname, dtype ) \
347- XLA_FFI_DEFINE_HANDLER_SYMBOL ( \
348- fname, SvdOnlyVtQRImpl<dtype>, \
349- ffi::Ffi::Bind () \
350- .Arg<ffi::Buffer<dtype>>(/* x*/ ) \
351- .Ret<ffi::Buffer<dtype>>(/* x_out*/ ) \
352- .Ret<ffi::Buffer<dtype>>(/* s*/ ) \
353- .Ret<ffi::Buffer<LapackIntDtype>>(/* info*/ ))
354-
355- #define DEFINE_COMPLEX_SVD_ONLY_VT_QR (fname, dtype ) \
356- XLA_FFI_DEFINE_HANDLER_SYMBOL ( \
357- fname, SvdOnlyVtQRImpl<dtype>, \
358- ffi::Ffi::Bind () \
359- .Arg<ffi::Buffer<dtype>>(/* x*/ ) \
360- .Ret<ffi::Buffer<dtype>>(/* x_out*/ ) \
361- .Ret<ffi::Buffer<ffi::ToReal(dtype)>>(/* s*/ ) \
362- .Ret<ffi::Buffer<LapackIntDtype>>(/* info*/ ))
386+ #define DEFINE_REAL_SVD_ONLY_VT_QR (fname, dtype ) \
387+ XLA_FFI_DEFINE_HANDLER_SYMBOL ( \
388+ fname, SvdOnlyVtQRImpl<dtype>, \
389+ ffi::Ffi::Bind () \
390+ .Arg<ffi::Buffer<dtype>>(/* x*/ ) \
391+ .Ret<ffi::Buffer<dtype>>(/* x_out*/ ) \
392+ .Ret<ffi::Buffer<dtype>>(/* s*/ ) \
393+ .Ret<ffi::Buffer<ffi::DataType::S32>>(/* info*/ ) \
394+ .Attr<UVtMode>(" mode" ))
395+
396+ #define DEFINE_COMPLEX_SVD_ONLY_VT_QR (fname, dtype ) \
397+ XLA_FFI_DEFINE_HANDLER_SYMBOL ( \
398+ fname, SvdOnlyVtQRImpl<dtype>, \
399+ ffi::Ffi::Bind () \
400+ .Arg<ffi::Buffer<dtype>>(/* x*/ ) \
401+ .Ret<ffi::Buffer<dtype>>(/* x_out*/ ) \
402+ .Ret<ffi::Buffer<ffi::ToReal(dtype)>>(/* s*/ ) \
403+ .Ret<ffi::Buffer<ffi::DataType::S32>>(/* info*/ ) \
404+ .Attr<UVtMode>(" mode" ))
363405
364406DEFINE_REAL_SVD_ONLY_VT_QR(svd_only_vt_qr_f32, ffi::DataType::F32);
365407DEFINE_REAL_SVD_ONLY_VT_QR (svd_only_vt_qr_f64, ffi::DataType::F64);
366408DEFINE_COMPLEX_SVD_ONLY_VT_QR (svd_only_vt_qr_c64, ffi::DataType::C64);
367409DEFINE_COMPLEX_SVD_ONLY_VT_QR (svd_only_vt_qr_c128, ffi::DataType::C128);
368410
369411template <typename T>
370- nb::capsule EncapsulateFfiCall (T *fn) {
412+ static nb::capsule EncapsulateFfiCall (T *fn) {
371413 static_assert (std::is_invocable_r_v<XLA_FFI_Error *, T, XLA_FFI_CallFrame *>,
372414 " Encapsulated function must be and XLA FFI handler" );
373415 return nb::capsule (reinterpret_cast <void *>(fn));
0 commit comments