@@ -475,6 +475,7 @@ def _batched_qr(a, mode="reduced"):
475475 )
476476
477477
478+ # pylint: disable=too-many-locals
478479def _batched_svd (
479480 a ,
480481 uv_type ,
@@ -532,29 +533,30 @@ def _batched_svd(
532533 batch_shape_orig ,
533534 )
534535
535- k = min (m , n )
536- if compute_uv :
537- if full_matrices :
538- u_shape = (m , m ) + (batch_size ,)
539- vt_shape = (n , n ) + (batch_size ,)
540- jobu = ord ("A" )
541- jobvt = ord ("A" )
542- else :
543- u_shape = (m , k ) + (batch_size ,)
544- vt_shape = (k , n ) + (batch_size ,)
545- jobu = ord ("S" )
546- jobvt = ord ("S" )
536+ # Transpose if m < n:
537+ # 1. cuSolver gesvd supports only m >= n
538+ # 2. Reducing a matrix with m >= n to bidiagonal form is more efficient
539+ if m < n :
540+ n , m = a .shape [- 2 :]
541+ trans_flag = True
547542 else :
548- u_shape = vt_shape = ()
549- jobu = ord ("N" )
550- jobvt = ord ("N" )
543+ trans_flag = False
544+
545+ u_shape , vt_shape , s_shape , jobu , jobvt = _get_svd_shapes_and_flags (
546+ m , n , compute_uv , full_matrices , batch_size = batch_size
547+ )
551548
552549 _manager = dpu .SequentialOrderManager [exec_q ]
553550 dep_evs = _manager .submitted_events
554551
555552 # Reorder the elements by moving the last two axes of `a` to the front
556553 # to match fortran-like array order which is assumed by gesvd.
557- a = dpnp .moveaxis (a , (- 2 , - 1 ), (0 , 1 ))
554+ if trans_flag :
555+ # Transpose axes for cuSolver and to optimize reduction
556+ # to bidiagonal form
557+ a = dpnp .moveaxis (a , (- 1 , - 2 ), (0 , 1 ))
558+ else :
559+ a = dpnp .moveaxis (a , (- 2 , - 1 ), (0 , 1 ))
558560
559561 # oneMKL LAPACK gesvd destroys `a` and assumes fortran-like array
560562 # as input.
@@ -583,7 +585,7 @@ def _batched_svd(
583585 sycl_queue = exec_q ,
584586 )
585587 s_h = dpnp .empty (
586- ( batch_size ,) + ( k ,) ,
588+ s_shape ,
587589 dtype = s_type ,
588590 order = "C" ,
589591 usm_type = usm_type ,
@@ -607,16 +609,23 @@ def _batched_svd(
607609 # gesvd call writes `u_h` and `vt_h` in Fortran order;
608610 # reorder the axes to match C order by moving the last axis
609611 # to the front
610- u = dpnp .moveaxis (u_h , - 1 , 0 )
611- vt = dpnp .moveaxis (vt_h , - 1 , 0 )
612+ if trans_flag :
613+ # Transpose axes to restore U and V^T for the original matrix
614+ u = dpnp .moveaxis (u_h , (0 , - 1 ), (- 1 , 0 ))
615+ vt = dpnp .moveaxis (vt_h , (0 , - 1 ), (- 1 , 0 ))
616+ else :
617+ u = dpnp .moveaxis (u_h , - 1 , 0 )
618+ vt = dpnp .moveaxis (vt_h , - 1 , 0 )
619+
612620 if a_ndim > 3 :
613621 u = u .reshape (batch_shape_orig + u .shape [- 2 :])
614622 vt = vt .reshape (batch_shape_orig + vt .shape [- 2 :])
615623 # dpnp.moveaxis can make the array non-contiguous if it is not 2D
616624 # Convert to contiguous to align with NumPy
617625 u = dpnp .ascontiguousarray (u )
618626 vt = dpnp .ascontiguousarray (vt )
619- return u , s , vt
627+ # Swap `u` and `vt` for transposed input to restore correct order
628+ return (vt , s , u ) if trans_flag else (u , s , vt )
620629 return s
621630
622631
@@ -759,6 +768,36 @@ def _common_inexact_type(default_dtype, *dtypes):
759768 return dpnp .result_type (* inexact_dtypes )
760769
761770
771+ def _get_svd_shapes_and_flags (m , n , compute_uv , full_matrices , batch_size = None ):
772+ """Return the shapes and flags for SVD computations."""
773+
774+ k = min (m , n )
775+ if compute_uv :
776+ if full_matrices :
777+ u_shape = (m , m )
778+ vt_shape = (n , n )
779+ jobu = ord ("A" )
780+ jobvt = ord ("A" )
781+ else :
782+ u_shape = (m , k )
783+ vt_shape = (k , n )
784+ jobu = ord ("S" )
785+ jobvt = ord ("S" )
786+ else :
787+ u_shape = vt_shape = ()
788+ jobu = ord ("N" )
789+ jobvt = ord ("N" )
790+
791+ s_shape = (k ,)
792+ if batch_size is not None :
793+ if compute_uv :
794+ u_shape += (batch_size ,)
795+ vt_shape += (batch_size ,)
796+ s_shape = (batch_size ,) + s_shape
797+
798+ return u_shape , vt_shape , s_shape , jobu , jobvt
799+
800+
762801def _hermitian_svd (a , compute_uv ):
763802 """
764803 _hermitian_svd(a, compute_uv)
@@ -2695,6 +2734,16 @@ def dpnp_svd(
26952734 a , uv_type , s_type , full_matrices , compute_uv , exec_q , usm_type
26962735 )
26972736
2737+ # Transpose if m < n:
2738+ # 1. cuSolver gesvd supports only m >= n
2739+ # 2. Reducing a matrix with m >= n to bidiagonal form is more efficient
2740+ if m < n :
2741+ n , m = a .shape
2742+ a = a .transpose ()
2743+ trans_flag = True
2744+ else :
2745+ trans_flag = False
2746+
26982747 # oneMKL LAPACK gesvd destroys `a` and assumes fortran-like array as input.
26992748 # Allocate 'F' order memory for dpnp arrays to comply with
27002749 # these requirements.
@@ -2716,22 +2765,9 @@ def dpnp_svd(
27162765 )
27172766 _manager .add_event_pair (ht_ev , copy_ev )
27182767
2719- k = min (m , n )
2720- if compute_uv :
2721- if full_matrices :
2722- u_shape = (m , m )
2723- vt_shape = (n , n )
2724- jobu = ord ("A" )
2725- jobvt = ord ("A" )
2726- else :
2727- u_shape = (m , k )
2728- vt_shape = (k , n )
2729- jobu = ord ("S" )
2730- jobvt = ord ("S" )
2731- else :
2732- u_shape = vt_shape = ()
2733- jobu = ord ("N" )
2734- jobvt = ord ("N" )
2768+ u_shape , vt_shape , s_shape , jobu , jobvt = _get_svd_shapes_and_flags (
2769+ m , n , compute_uv , full_matrices
2770+ )
27352771
27362772 # oneMKL LAPACK assumes fortran-like array as input.
27372773 # Allocate 'F' order memory for dpnp output arrays to comply with
@@ -2746,7 +2782,7 @@ def dpnp_svd(
27462782 shape = vt_shape ,
27472783 order = "F" ,
27482784 )
2749- s_h = dpnp .empty_like (a_h , shape = ( k ,) , dtype = s_type )
2785+ s_h = dpnp .empty_like (a_h , shape = s_shape , dtype = s_type )
27502786
27512787 ht_ev , gesvd_ev = li ._gesvd (
27522788 exec_q ,
@@ -2761,6 +2797,11 @@ def dpnp_svd(
27612797 _manager .add_event_pair (ht_ev , gesvd_ev )
27622798
27632799 if compute_uv :
2800+ # Transposing the input matrix swaps the roles of U and Vt:
2801+ # For A^T = V S^T U^T, `u_h` becomes V and `vt_h` becomes U^T.
2802+ # Transpose and swap them back to restore correct order for A.
2803+ if trans_flag :
2804+ return vt_h .T , s_h , u_h .T
27642805 # gesvd call writes `u_h` and `vt_h` in Fortran order;
27652806 # Convert to contiguous to align with NumPy
27662807 u_h = dpnp .ascontiguousarray (u_h )
0 commit comments