@@ -4,7 +4,7 @@ use crate::{
44 context:: CublasContext ,
55 error:: { Error , ToResult } ,
66 raw:: { ComplexLevel1 , FloatLevel1 , Level1 } ,
7- BlasDatatype ,
7+ BlasDatatype , Float ,
88} ;
99use cust:: memory:: { GpuBox , GpuBuffer } ;
1010use cust:: stream:: Stream ;
@@ -641,4 +641,172 @@ impl CublasContext {
641641 ) -> Result {
642642 self . rot_strided ( stream, n, x, None , y, None , c, s)
643643 }
644+
645+ /// Constructs the givens rotation matrix that zeros out the second entry of a 2x1 vector.
646+ pub fn rotg < T : Level1 > (
647+ & mut self ,
648+ stream : & Stream ,
649+ a : & mut impl GpuBox < T > ,
650+ b : & mut impl GpuBox < T > ,
651+ c : & mut impl GpuBox < T :: FloatTy > ,
652+ s : & mut impl GpuBox < T > ,
653+ ) -> Result {
654+ self . with_stream ( stream, |ctx| unsafe {
655+ Ok ( T :: rotg (
656+ ctx. raw ,
657+ a. as_device_ptr ( ) . as_mut_ptr ( ) ,
658+ b. as_device_ptr ( ) . as_mut_ptr ( ) ,
659+ c. as_device_ptr ( ) . as_mut_ptr ( ) ,
660+ s. as_device_ptr ( ) . as_mut_ptr ( ) ,
661+ )
662+ . to_result ( ) ?)
663+ } )
664+ }
665+
666+ /// Same as [`CublasContext::rotm`] but with an explicit stride.
667+ pub fn rotm_strided < T : Level1 + Float > (
668+ & mut self ,
669+ stream : & Stream ,
670+ n : usize ,
671+ x : & mut impl GpuBuffer < T > ,
672+ x_stride : Option < usize > ,
673+ y : & mut impl GpuBuffer < T > ,
674+ y_stride : Option < usize > ,
675+ param : & impl GpuBox < T :: FloatTy > ,
676+ ) -> Result {
677+ check_stride ( x, n, x_stride) ;
678+ check_stride ( y, n, y_stride) ;
679+
680+ self . with_stream ( stream, |ctx| unsafe {
681+ Ok ( T :: rotm (
682+ ctx. raw ,
683+ n as i32 ,
684+ x. as_device_ptr ( ) . as_mut_ptr ( ) ,
685+ x_stride. unwrap_or ( 1 ) as i32 ,
686+ y. as_device_ptr ( ) . as_mut_ptr ( ) ,
687+ y_stride. unwrap_or ( 1 ) as i32 ,
688+ param. as_device_ptr ( ) . as_ptr ( ) ,
689+ )
690+ . to_result ( ) ?)
691+ } )
692+ }
693+
694+ /// Applies the modified givens transformation to vectors `x` and `y`.
695+ pub fn rotm < T : Level1 + Float > (
696+ & mut self ,
697+ stream : & Stream ,
698+ n : usize ,
699+ x : & mut impl GpuBuffer < T > ,
700+ y : & mut impl GpuBuffer < T > ,
701+ param : & impl GpuBox < T :: FloatTy > ,
702+ ) -> Result {
703+ self . rotm_strided ( stream, n, x, None , y, None , param)
704+ }
705+
706+ /// Same as [`CublasContext::rotmg`] but with an explicit stride.
707+ pub fn rotmg_strided < T : Level1 + Float > (
708+ & mut self ,
709+ stream : & Stream ,
710+ d1 : & mut impl GpuBox < T > ,
711+ d2 : & mut impl GpuBox < T > ,
712+ x1 : & mut impl GpuBox < T > ,
713+ y1 : & mut impl GpuBox < T > ,
714+ param : & mut impl GpuBox < T > ,
715+ ) -> Result {
716+ self . with_stream ( stream, |ctx| unsafe {
717+ Ok ( T :: rotmg (
718+ ctx. raw ,
719+ d1. as_device_ptr ( ) . as_mut_ptr ( ) ,
720+ d2. as_device_ptr ( ) . as_mut_ptr ( ) ,
721+ x1. as_device_ptr ( ) . as_mut_ptr ( ) ,
722+ y1. as_device_ptr ( ) . as_ptr ( ) ,
723+ param. as_device_ptr ( ) . as_mut_ptr ( ) ,
724+ )
725+ . to_result ( ) ?)
726+ } )
727+ }
728+
729+ /// Constructs the modified givens transformation that zeros out the second entry of a 2x1 vector.
730+ pub fn rotmg < T : Level1 + Float > (
731+ & mut self ,
732+ stream : & Stream ,
733+ d1 : & mut impl GpuBox < T > ,
734+ d2 : & mut impl GpuBox < T > ,
735+ x1 : & mut impl GpuBox < T > ,
736+ y1 : & mut impl GpuBox < T > ,
737+ param : & mut impl GpuBox < T > ,
738+ ) -> Result {
739+ self . rotmg_strided ( stream, d1, d2, x1, y1, param)
740+ }
741+
742+ /// Same as [`CublasContext::scal`] but with an explicit stride.
743+ pub fn scal_strided < T : Level1 > (
744+ & mut self ,
745+ stream : & Stream ,
746+ n : usize ,
747+ alpha : & impl GpuBox < T > ,
748+ x : & mut impl GpuBuffer < T > ,
749+ x_stride : Option < usize > ,
750+ ) -> Result {
751+ check_stride ( x, n, x_stride) ;
752+
753+ self . with_stream ( stream, |ctx| unsafe {
754+ Ok ( T :: scal (
755+ ctx. raw ,
756+ n as i32 ,
757+ alpha. as_device_ptr ( ) . as_ptr ( ) ,
758+ x. as_device_ptr ( ) . as_mut_ptr ( ) ,
759+ x_stride. unwrap_or ( 1 ) as i32 ,
760+ )
761+ . to_result ( ) ?)
762+ } )
763+ }
764+
765+ /// Scales vector `x` by `alpha` and overrides it with the result.
766+ pub fn scal < T : Level1 > (
767+ & mut self ,
768+ stream : & Stream ,
769+ n : usize ,
770+ alpha : & impl GpuBox < T > ,
771+ x : & mut impl GpuBuffer < T > ,
772+ ) -> Result {
773+ self . scal_strided ( stream, n, alpha, x, None )
774+ }
775+
776+ /// Same as [`CublasContext::swap`] but with an explicit stride.
777+ pub fn swap_strided < T : Level1 > (
778+ & mut self ,
779+ stream : & Stream ,
780+ n : usize ,
781+ x : & mut impl GpuBuffer < T > ,
782+ x_stride : Option < usize > ,
783+ y : & mut impl GpuBuffer < T > ,
784+ y_stride : Option < usize > ,
785+ ) -> Result {
786+ check_stride ( x, n, x_stride) ;
787+ check_stride ( y, n, y_stride) ;
788+
789+ self . with_stream ( stream, |ctx| unsafe {
790+ Ok ( T :: swap (
791+ ctx. raw ,
792+ n as i32 ,
793+ x. as_device_ptr ( ) . as_mut_ptr ( ) ,
794+ x_stride. unwrap_or ( 1 ) as i32 ,
795+ y. as_device_ptr ( ) . as_mut_ptr ( ) ,
796+ y_stride. unwrap_or ( 1 ) as i32 ,
797+ )
798+ . to_result ( ) ?)
799+ } )
800+ }
801+
802+ /// Swaps vectors `x` and `y`.
803+ pub fn swap < T : Level1 > (
804+ & mut self ,
805+ stream : & Stream ,
806+ n : usize ,
807+ x : & mut impl GpuBuffer < T > ,
808+ y : & mut impl GpuBuffer < T > ,
809+ ) -> Result {
810+ self . swap_strided ( stream, n, x, None , y, None )
811+ }
644812}
0 commit comments