1- #include " lapack.h"
21#include " nanobind/nanobind.h"
32#include " xla/ffi/api/ffi.h"
43
4+ using lapack_int = int ;
5+
56namespace nb = nanobind;
67using namespace ::xla;
78
8- nb::module_ cython_lapack = nb::module_::import_(" scipy.linalg.cython_lapack" );
9-
109inline constexpr auto LapackIntDtype = ffi::DataType::S32;
1110static_assert (std::is_same_v<::xla::ffi::NativeType<LapackIntDtype>, lapack_int>);
1211
@@ -30,8 +29,8 @@ static ffi::Error SvdOnlyVtImpl(
3029 MachineType* work, lapack_int const * lwork,
3130 RealType* rwork,
3231 lapack_int* iwork,
33- lapack_int* info,
34- size_t strlen ),
32+ lapack_int* info
33+ ),
3534 void (char const * jobz,
3635 lapack_int const * m, lapack_int const * n,
3736 MachineType* A, lapack_int const * lda,
@@ -40,22 +39,39 @@ static ffi::Error SvdOnlyVtImpl(
4039 MachineType* VT, lapack_int const * ldvt,
4140 MachineType* work, lapack_int const * lwork,
4241 lapack_int* iwork,
43- lapack_int* info,
44- size_t strlen )>;
42+ lapack_int* info
43+ )>;
4544
4645 FnSig* fn = nullptr ;
4746
48- if constexpr (dtype == ffi::DataType::F32) {
49- fn = sgesdd_;
50- }
51- if constexpr (dtype == ffi::DataType::F64) {
52- fn = dgesdd_;
53- }
54- if constexpr (dtype == ffi::DataType::C64) {
55- fn = reinterpret_cast <FnSig*>(cgesdd_);
56- }
57- if constexpr (dtype == ffi::DataType::C128) {
58- fn = reinterpret_cast <FnSig*>(zgesdd_);
47+ try {
48+ PyGILState_STATE state = PyGILState_Ensure ();
49+
50+ nb::module_ cython_lapack = nb::module_::import_ (" scipy.linalg.cython_lapack" );
51+
52+ nb::dict lapack_capi = cython_lapack.attr (" __pyx_capi__" );
53+
54+ auto get_lapack_ptr = [&](const char * name) {
55+ return nb::cast<nb::capsule>(lapack_capi[name]).data ();
56+ };
57+
58+ if constexpr (dtype == ffi::DataType::F32) {
59+ fn = reinterpret_cast <FnSig*>(get_lapack_ptr (" sgesdd" ));
60+ }
61+ if constexpr (dtype == ffi::DataType::F64) {
62+ fn = reinterpret_cast <FnSig*>(get_lapack_ptr (" dgesdd" ));
63+ }
64+ if constexpr (dtype == ffi::DataType::C64) {
65+ fn = reinterpret_cast <FnSig*>(get_lapack_ptr (" cgesdd" ));
66+ }
67+ if constexpr (dtype == ffi::DataType::C128) {
68+ fn = reinterpret_cast <FnSig*>(get_lapack_ptr (" zgesdd" ));
69+ }
70+
71+ PyGILState_Release (state);
72+ } catch (const nb::python_error &e) {
73+ std::cerr << e.what () << std::endl;
74+ throw ;
5975 }
6076
6177 const auto lapack_int_max = std::numeric_limits<lapack_int>::max ();
@@ -96,12 +112,14 @@ static ffi::Error SvdOnlyVtImpl(
96112 fn (&jobz, &x_rows_lapack, &x_cols_lapack, nullptr ,
97113 &x_rows_lapack, nullptr , nullptr ,
98114 &ldu, nullptr , &x_cols_lapack, &work_size,
99- &lwork, nullptr , nullptr , info_data, 1 );
115+ &lwork, nullptr , nullptr , info_data
116+ );
100117 } else {
101118 fn (&jobz, &x_rows_lapack, &x_cols_lapack, nullptr ,
102119 &x_rows_lapack, nullptr , nullptr ,
103120 &ldu, nullptr , &x_cols_lapack,
104- &work_size, &lwork, nullptr , info_data, 1 );
121+ &work_size, &lwork, nullptr , info_data
122+ );
105123 }
106124
107125 if (*info_data != 0 ) {
@@ -131,12 +149,14 @@ static ffi::Error SvdOnlyVtImpl(
131149 fn (&jobz, &x_rows_lapack, &x_cols_lapack, x_out_data,
132150 &x_rows_lapack, s_data, nullptr ,
133151 &ldu, vt_data, &x_cols_lapack, work.get (),
134- &lwork, rwork.get (), iwork.get (), info_data, 1 );
152+ &lwork, rwork.get (), iwork.get (), info_data
153+ );
135154 } else {
136155 fn (&jobz, &x_rows_lapack, &x_cols_lapack, x_out_data,
137156 &x_rows_lapack, s_data, nullptr ,
138157 &ldu, vt_data, &x_cols_lapack,
139- work.get (), &lwork, iwork.get (), info_data, 1 );
158+ work.get (), &lwork, iwork.get (), info_data
159+ );
140160 }
141161
142162 if (*info_data != 0 ) {
@@ -164,31 +184,48 @@ static ffi::Error SvdOnlyVtQRImpl(
164184 MachineType* VT, lapack_int const * ldvt,
165185 MachineType* work, lapack_int const * lwork,
166186 RealType* rwork,
167- lapack_int* info,
168- size_t strlen1, size_t strlen2 ),
187+ lapack_int* info
188+ ),
169189 void (char const * jobu, char const * jobvt,
170190 lapack_int const * m, lapack_int const * n,
171191 MachineType* A, lapack_int const * lda,
172192 RealType* S,
173193 MachineType* U, lapack_int const * ldu,
174194 MachineType* VT, lapack_int const * ldvt,
175195 MachineType* work, lapack_int const * lwork,
176- lapack_int* info,
177- size_t strlen1, size_t strlen2 )>;
196+ lapack_int* info
197+ )>;
178198
179199 FnSig* fn = nullptr ;
180200
181- if constexpr (dtype == ffi::DataType::F32) {
182- fn = sgesvd_;
183- }
184- if constexpr (dtype == ffi::DataType::F64) {
185- fn = dgesvd_;
186- }
187- if constexpr (dtype == ffi::DataType::C64) {
188- fn = reinterpret_cast <FnSig*>(cgesvd_);
189- }
190- if constexpr (dtype == ffi::DataType::C128) {
191- fn = reinterpret_cast <FnSig*>(zgesvd_);
201+ try {
202+ PyGILState_STATE state = PyGILState_Ensure ();
203+
204+ nb::module_ cython_lapack = nb::module_::import_ (" scipy.linalg.cython_lapack" );
205+
206+ nb::dict lapack_capi = cython_lapack.attr (" __pyx_capi__" );
207+
208+ auto get_lapack_ptr = [&](const char * name) {
209+ return nb::cast<nb::capsule>(lapack_capi[name]).data ();
210+ };
211+
212+ if constexpr (dtype == ffi::DataType::F32) {
213+ fn = reinterpret_cast <FnSig*>(get_lapack_ptr (" sgesvd" ));
214+ }
215+ if constexpr (dtype == ffi::DataType::F64) {
216+ fn = reinterpret_cast <FnSig*>(get_lapack_ptr (" dgesvd" ));
217+ }
218+ if constexpr (dtype == ffi::DataType::C64) {
219+ fn = reinterpret_cast <FnSig*>(get_lapack_ptr (" cgesvd" ));
220+ }
221+ if constexpr (dtype == ffi::DataType::C128) {
222+ fn = reinterpret_cast <FnSig*>(get_lapack_ptr (" zgesvd" ));
223+ }
224+
225+ PyGILState_Release (state);
226+ } catch (const nb::python_error &e) {
227+ std::cerr << e.what () << std::endl;
228+ throw ;
192229 }
193230
194231 const auto lapack_int_max = std::numeric_limits<lapack_int>::max ();
@@ -230,12 +267,14 @@ static ffi::Error SvdOnlyVtQRImpl(
230267 fn (&jobu, &jobvt, &x_rows_lapack, &x_cols_lapack, nullptr ,
231268 &x_rows_lapack, nullptr , nullptr ,
232269 &ldu, nullptr , &x_cols_lapack, &work_size,
233- &lwork, nullptr , info_data, 1 , 1 );
270+ &lwork, nullptr , info_data
271+ );
234272 } else {
235273 fn (&jobu, &jobvt, &x_rows_lapack, &x_cols_lapack, nullptr ,
236274 &x_rows_lapack, nullptr , nullptr ,
237275 &ldu, nullptr , &x_cols_lapack,
238- &work_size, &lwork, info_data, 1 , 1 );
276+ &work_size, &lwork, info_data
277+ );
239278 }
240279
241280 if (*info_data != 0 ) {
@@ -262,12 +301,14 @@ static ffi::Error SvdOnlyVtQRImpl(
262301 fn (&jobu, &jobvt, &x_rows_lapack, &x_cols_lapack, x_out_data,
263302 &x_rows_lapack, s_data, nullptr ,
264303 &ldu, nullptr , &x_cols_lapack, work.get (),
265- &lwork, rwork.get (), info_data, 1 , 1 );
304+ &lwork, rwork.get (), info_data
305+ );
266306 } else {
267307 fn (&jobu, &jobvt, &x_rows_lapack, &x_cols_lapack, x_out_data,
268308 &x_rows_lapack, s_data, nullptr ,
269309 &ldu, nullptr , &x_cols_lapack,
270- work.get (), &lwork, info_data, 1 , 1 );
310+ work.get (), &lwork, info_data
311+ );
271312 }
272313
273314 if (*info_data != 0 ) {
@@ -333,12 +374,12 @@ nb::capsule EncapsulateFfiCall(T *fn) {
333374}
334375
335376NB_MODULE (_svd_only_vt, m) {
336- m.def (" svd_only_vt_f32" , []() { return EncapsulateFfiCall (svd_only_vt_f32); });
337- m.def (" svd_only_vt_f64" , []() { return EncapsulateFfiCall (svd_only_vt_f64); });
338- m.def (" svd_only_vt_c64" , []() { return EncapsulateFfiCall (svd_only_vt_c64); });
339- m.def (" svd_only_vt_c128" , []() { return EncapsulateFfiCall (svd_only_vt_c128); });
340- m.def (" svd_only_vt_qr_f32" , []() { return EncapsulateFfiCall (svd_only_vt_qr_f32); });
341- m.def (" svd_only_vt_qr_f64" , []() { return EncapsulateFfiCall (svd_only_vt_qr_f64); });
342- m.def (" svd_only_vt_qr_c64" , []() { return EncapsulateFfiCall (svd_only_vt_qr_c64); });
343- m.def (" svd_only_vt_qr_c128" , []() { return EncapsulateFfiCall (svd_only_vt_qr_c128); });
377+ m.def (" svd_only_vt_f32" , []() { return EncapsulateFfiCall (svd_only_vt_f32); }, nb::call_guard<nb::gil_scoped_release>() );
378+ m.def (" svd_only_vt_f64" , []() { return EncapsulateFfiCall (svd_only_vt_f64); }, nb::call_guard<nb::gil_scoped_release>() );
379+ m.def (" svd_only_vt_c64" , []() { return EncapsulateFfiCall (svd_only_vt_c64); }, nb::call_guard<nb::gil_scoped_release>() );
380+ m.def (" svd_only_vt_c128" , []() { return EncapsulateFfiCall (svd_only_vt_c128); }, nb::call_guard<nb::gil_scoped_release>() );
381+ m.def (" svd_only_vt_qr_f32" , []() { return EncapsulateFfiCall (svd_only_vt_qr_f32); }, nb::call_guard<nb::gil_scoped_release>() );
382+ m.def (" svd_only_vt_qr_f64" , []() { return EncapsulateFfiCall (svd_only_vt_qr_f64); }, nb::call_guard<nb::gil_scoped_release>() );
383+ m.def (" svd_only_vt_qr_c64" , []() { return EncapsulateFfiCall (svd_only_vt_qr_c64); }, nb::call_guard<nb::gil_scoped_release>() );
384+ m.def (" svd_only_vt_qr_c128" , []() { return EncapsulateFfiCall (svd_only_vt_qr_c128); }, nb::call_guard<nb::gil_scoped_release>() );
344385}
0 commit comments