1- // Copyright 2014-2016 bluss and ndarray developers.
1+ // Copyright 2014-2020 bluss and ndarray developers.
22//
33// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
44// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
@@ -325,9 +325,9 @@ where
325325
326326 // Avoid initializing the memory in vec -- set it during iteration
327327 unsafe {
328- let mut c = Array :: uninitialized ( m) ;
329- general_mat_vec_mul ( A :: one ( ) , self , rhs, A :: zero ( ) , & mut c ) ;
330- c
328+ let mut c = Array1 :: maybe_uninit ( m) ;
329+ general_mat_vec_mul_impl ( A :: one ( ) , self , rhs, A :: zero ( ) , c . raw_view_mut ( ) . cast :: < A > ( ) ) ;
330+ c. assume_init ( )
331331 }
332332 }
333333}
@@ -598,6 +598,30 @@ pub fn general_mat_vec_mul<A, S1, S2, S3>(
598598 S2 : Data < Elem = A > ,
599599 S3 : DataMut < Elem = A > ,
600600 A : LinalgScalar ,
601+ {
602+ unsafe {
603+ general_mat_vec_mul_impl ( alpha, a, x, beta, y. raw_view_mut ( ) )
604+ }
605+ }
606+
607+ /// General matrix-vector multiplication
608+ ///
609+ /// Use a raw view for the destination vector, so that it can be uninitialized.
610+ ///
611+ /// ## Safety
612+ ///
613+ /// The caller must ensure that the raw view is valid for writing.
614+ /// The destination may be uninitialized iff beta is zero.
615+ unsafe fn general_mat_vec_mul_impl < A , S1 , S2 > (
616+ alpha : A ,
617+ a : & ArrayBase < S1 , Ix2 > ,
618+ x : & ArrayBase < S2 , Ix1 > ,
619+ beta : A ,
620+ y : RawArrayViewMut < A , Ix1 > ,
621+ ) where
622+ S1 : Data < Elem = A > ,
623+ S2 : Data < Elem = A > ,
624+ A : LinalgScalar ,
601625{
602626 let ( ( m, k) , k2) = ( a. dim ( ) , x. dim ( ) ) ;
603627 let m2 = y. dim ( ) ;
@@ -626,22 +650,20 @@ pub fn general_mat_vec_mul<A, S1, S2, S3>(
626650 let x_stride = x. strides( ) [ 0 ] as blas_index;
627651 let y_stride = y. strides( ) [ 0 ] as blas_index;
628652
629- unsafe {
630- blas_sys:: $gemv(
631- layout,
632- a_trans,
633- m as blas_index, // m, rows of Op(a)
634- k as blas_index, // n, cols of Op(a)
635- cast_as( & alpha) , // alpha
636- a. ptr. as_ptr( ) as * const _, // a
637- a_stride, // lda
638- x. ptr. as_ptr( ) as * const _, // x
639- x_stride,
640- cast_as( & beta) , // beta
641- y. ptr. as_ptr( ) as * mut _, // x
642- y_stride,
643- ) ;
644- }
653+ blas_sys:: $gemv(
654+ layout,
655+ a_trans,
656+ m as blas_index, // m, rows of Op(a)
657+ k as blas_index, // n, cols of Op(a)
658+ cast_as( & alpha) , // alpha
659+ a. ptr. as_ptr( ) as * const _, // a
660+ a_stride, // lda
661+ x. ptr. as_ptr( ) as * const _, // x
662+ x_stride,
663+ cast_as( & beta) , // beta
665+ y. ptr. as_ptr( ) as * mut _, // y
665+ y_stride,
666+ ) ;
645667 return ;
646668 }
647669 }
@@ -655,8 +677,9 @@ pub fn general_mat_vec_mul<A, S1, S2, S3>(
655677 /* general */
656678
657679 if beta. is_zero ( ) {
680+ // when beta is zero, y may be uninitialized
658681 Zip :: from ( a. outer_iter ( ) ) . and ( y) . apply ( |row, elt| {
659- * elt = row. dot ( x) * alpha;
682+ elt. write ( row. dot ( x) * alpha) ;
660683 } ) ;
661684 } else {
662685 Zip :: from ( a. outer_iter ( ) ) . and ( y) . apply ( |row, elt| {
@@ -683,7 +706,7 @@ fn cast_as<A: 'static + Copy, B: 'static + Copy>(a: &A) -> B {
683706#[ cfg( feature = "blas" ) ]
684707fn blas_compat_1d < A , S > ( a : & ArrayBase < S , Ix1 > ) -> bool
685708where
686- S : Data ,
709+ S : RawData ,
687710 A : ' static ,
688711 S :: Elem : ' static ,
689712{
0 commit comments