@@ -11,7 +11,7 @@ namespace ffi = xla::ffi;
1111XLA_FFI_REGISTER_ENUM_ATTR_DECODING (UVtMode);
1212
1313template <ffi::DataType dtype>
14- static ffi::Error SvdOnlyVtImpl (
14+ static ffi::Error SvdOnlyUVtImpl (
1515 ffi::Buffer<dtype> x,
1616 ffi::ResultBuffer<dtype> x_out,
1717 ffi::ResultBuffer<ffi::ToReal(dtype)> s,
@@ -189,7 +189,7 @@ static ffi::Error SvdOnlyVtImpl(
189189}
190190
191191template <ffi::DataType dtype>
192- static ffi::Error SvdOnlyVtQRImpl (
192+ static ffi::Error SvdOnlyUVtQRImpl (
193193 ffi::Buffer<dtype> x,
194194 ffi::ResultBuffer<dtype> x_out,
195195 ffi::ResultBuffer<ffi::ToReal(dtype)> s,
@@ -382,7 +382,7 @@ static ffi::Error SvdOnlyVtQRImpl(
382382
383383#define DEFINE_REAL_SVD_ONLY_VT (fname, dtype ) \
384384 XLA_FFI_DEFINE_HANDLER_SYMBOL ( \
385- fname, SvdOnlyVtImpl <dtype>, \
385+ fname, SvdOnlyUVtImpl <dtype>, \
386386 ffi::Ffi::Bind () \
387387 .Arg<ffi::Buffer<dtype>>(/* x*/ ) \
388388 .Ret<ffi::Buffer<dtype>>(/* x_out*/ ) \
@@ -393,7 +393,7 @@ static ffi::Error SvdOnlyVtQRImpl(
393393
394394#define DEFINE_COMPLEX_SVD_ONLY_VT (fname, dtype ) \
395395 XLA_FFI_DEFINE_HANDLER_SYMBOL ( \
396- fname, SvdOnlyVtImpl <dtype>, \
396+ fname, SvdOnlyUVtImpl <dtype>, \
397397 ffi::Ffi::Bind () \
398398 .Arg<ffi::Buffer<dtype>>(/* x*/ ) \
399399 .Ret<ffi::Buffer<dtype>>(/* x_out*/ ) \
@@ -402,14 +402,14 @@ static ffi::Error SvdOnlyVtQRImpl(
402402 .Ret<ffi::Buffer<ffi::DataType::S32>>(/* info*/ ) \
403403 .Attr<UVtMode>(" mode" ))
404404
405- DEFINE_REAL_SVD_ONLY_VT(svd_only_vt_f32 , ffi::DataType::F32);
406- DEFINE_REAL_SVD_ONLY_VT (svd_only_vt_f64 , ffi::DataType::F64);
407- DEFINE_COMPLEX_SVD_ONLY_VT (svd_only_vt_c64 , ffi::DataType::C64);
408- DEFINE_COMPLEX_SVD_ONLY_VT (svd_only_vt_c128 , ffi::DataType::C128);
405+ DEFINE_REAL_SVD_ONLY_VT(svd_only_u_vt_f32 , ffi::DataType::F32);
406+ DEFINE_REAL_SVD_ONLY_VT (svd_only_u_vt_f64 , ffi::DataType::F64);
407+ DEFINE_COMPLEX_SVD_ONLY_VT (svd_only_u_vt_c64 , ffi::DataType::C64);
408+ DEFINE_COMPLEX_SVD_ONLY_VT (svd_only_u_vt_c128 , ffi::DataType::C128);
409409
410410#define DEFINE_REAL_SVD_ONLY_VT_QR (fname, dtype ) \
411411 XLA_FFI_DEFINE_HANDLER_SYMBOL ( \
412- fname, SvdOnlyVtQRImpl <dtype>, \
412+ fname, SvdOnlyUVtQRImpl <dtype>, \
413413 ffi::Ffi::Bind () \
414414 .Arg<ffi::Buffer<dtype>>(/* x*/ ) \
415415 .Ret<ffi::Buffer<dtype>>(/* x_out*/ ) \
@@ -420,7 +420,7 @@ DEFINE_COMPLEX_SVD_ONLY_VT(svd_only_vt_c128, ffi::DataType::C128);
420420
421421#define DEFINE_COMPLEX_SVD_ONLY_VT_QR (fname, dtype ) \
422422 XLA_FFI_DEFINE_HANDLER_SYMBOL ( \
423- fname, SvdOnlyVtQRImpl <dtype>, \
423+ fname, SvdOnlyUVtQRImpl <dtype>, \
424424 ffi::Ffi::Bind () \
425425 .Arg<ffi::Buffer<dtype>>(/* x*/ ) \
426426 .Ret<ffi::Buffer<dtype>>(/* x_out*/ ) \
@@ -429,10 +429,10 @@ DEFINE_COMPLEX_SVD_ONLY_VT(svd_only_vt_c128, ffi::DataType::C128);
429429 .Ret<ffi::Buffer<ffi::DataType::S32>>(/* info*/ ) \
430430 .Attr<UVtMode>(" mode" ))
431431
432- DEFINE_REAL_SVD_ONLY_VT_QR(svd_only_vt_qr_f32 , ffi::DataType::F32);
433- DEFINE_REAL_SVD_ONLY_VT_QR (svd_only_vt_qr_f64 , ffi::DataType::F64);
434- DEFINE_COMPLEX_SVD_ONLY_VT_QR (svd_only_vt_qr_c64 , ffi::DataType::C64);
435- DEFINE_COMPLEX_SVD_ONLY_VT_QR (svd_only_vt_qr_c128 , ffi::DataType::C128);
432+ DEFINE_REAL_SVD_ONLY_VT_QR(svd_only_u_vt_qr_f32 , ffi::DataType::F32);
433+ DEFINE_REAL_SVD_ONLY_VT_QR (svd_only_u_vt_qr_f64 , ffi::DataType::F64);
434+ DEFINE_COMPLEX_SVD_ONLY_VT_QR (svd_only_u_vt_qr_c64 , ffi::DataType::C64);
435+ DEFINE_COMPLEX_SVD_ONLY_VT_QR (svd_only_u_vt_qr_c128 , ffi::DataType::C128);
436436
437437template <typename T>
438438static nb::capsule EncapsulateFfiCall (T *fn) {
@@ -441,13 +441,13 @@ static nb::capsule EncapsulateFfiCall(T *fn) {
441441 return nb::capsule (reinterpret_cast <void *>(fn));
442442}
443443
444- NB_MODULE (_svd_only_vt , m) {
445- m.def (" svd_only_vt_f32 " , []() { return EncapsulateFfiCall (svd_only_vt_f32 ); });
446- m.def (" svd_only_vt_f64 " , []() { return EncapsulateFfiCall (svd_only_vt_f64 ); });
447- m.def (" svd_only_vt_c64 " , []() { return EncapsulateFfiCall (svd_only_vt_c64 ); });
448- m.def (" svd_only_vt_c128 " , []() { return EncapsulateFfiCall (svd_only_vt_c128 ); });
449- m.def (" svd_only_vt_qr_f32 " , []() { return EncapsulateFfiCall (svd_only_vt_qr_f32 ); });
450- m.def (" svd_only_vt_qr_f64 " , []() { return EncapsulateFfiCall (svd_only_vt_qr_f64 ); });
451- m.def (" svd_only_vt_qr_c64 " , []() { return EncapsulateFfiCall (svd_only_vt_qr_c64 ); });
452- m.def (" svd_only_vt_qr_c128 " , []() { return EncapsulateFfiCall (svd_only_vt_qr_c128 ); });
444+ NB_MODULE (_svd_only_u_vt , m) {
445+ m.def (" svd_only_u_vt_f32 " , []() { return EncapsulateFfiCall (svd_only_u_vt_f32 ); });
446+ m.def (" svd_only_u_vt_f64 " , []() { return EncapsulateFfiCall (svd_only_u_vt_f64 ); });
447+ m.def (" svd_only_u_vt_c64 " , []() { return EncapsulateFfiCall (svd_only_u_vt_c64 ); });
448+ m.def (" svd_only_u_vt_c128 " , []() { return EncapsulateFfiCall (svd_only_u_vt_c128 ); });
449+ m.def (" svd_only_u_vt_qr_f32 " , []() { return EncapsulateFfiCall (svd_only_u_vt_qr_f32 ); });
450+ m.def (" svd_only_u_vt_qr_f64 " , []() { return EncapsulateFfiCall (svd_only_u_vt_qr_f64 ); });
451+ m.def (" svd_only_u_vt_qr_c64 " , []() { return EncapsulateFfiCall (svd_only_u_vt_qr_c64 ); });
452+ m.def (" svd_only_u_vt_qr_c128 " , []() { return EncapsulateFfiCall (svd_only_u_vt_qr_c128 ); });
453453}
0 commit comments