@@ -53,17 +53,26 @@ where
5353 D : Dimension ,
5454{
5555 pub ( crate ) fn layout_impl ( & self ) -> Layout {
56- Layout :: new ( if self . is_standard_layout ( ) {
57- if self . ndim ( ) <= 1 {
58- FORDER | CORDER
56+ let n = self . ndim ( ) ;
57+ if self . is_standard_layout ( ) {
58+ if n <= 1 {
59+ Layout :: one_dimensional ( )
5960 } else {
60- CORDER
61+ Layout :: c ( )
62+ }
63+ } else if n > 1 && self . raw_view ( ) . reversed_axes ( ) . is_standard_layout ( ) {
64+ Layout :: f ( )
65+ } else if n > 1 {
66+ if self . stride_of ( Axis ( 0 ) ) == 1 {
67+ Layout :: fpref ( )
68+ } else if self . stride_of ( Axis ( n - 1 ) ) == 1 {
69+ Layout :: cpref ( )
70+ } else {
71+ Layout :: none ( )
6172 }
62- } else if self . ndim ( ) > 1 && self . raw_view ( ) . reversed_axes ( ) . is_standard_layout ( ) {
63- FORDER
6473 } else {
65- 0
66- } )
74+ Layout :: none ( )
75+ }
6776 }
6877}
6978
@@ -587,6 +596,9 @@ pub struct Zip<Parts, D> {
587596 parts : Parts ,
588597 dimension : D ,
589598 layout : Layout ,
599+ /// The sum of the layout tendencies of the parts;
600+ /// positive for c- and negative for f-layout preference.
601+ layout_tendency : i32 ,
590602}
591603
592604
@@ -605,10 +617,12 @@ where
605617 {
606618 let array = p. into_producer ( ) ;
607619 let dim = array. raw_dim ( ) ;
620+ let layout = array. layout ( ) ;
608621 Zip {
609622 dimension : dim,
610- layout : array . layout ( ) ,
623+ layout,
611624 parts : ( array, ) ,
625+ layout_tendency : layout. tendency ( ) ,
612626 }
613627 }
614628}
@@ -661,24 +675,29 @@ where
661675 self . dimension [ axis. index ( ) ]
662676 }
663677
678+ fn prefer_f ( & self ) -> bool {
679+ !self . layout . is ( CORDER ) && ( self . layout . is ( FORDER ) || self . layout_tendency < 0 )
680+ }
681+
664682 /// Return an *approximation* to the max stride axis; if
665683 /// component arrays disagree, there may be no choice better than the
666684 /// others.
667685 fn max_stride_axis ( & self ) -> Axis {
668- let i = match self . layout . flag ( ) {
669- FORDER => self
686+ let i = if self . prefer_f ( ) {
687+ self
670688 . dimension
671689 . slice ( )
672690 . iter ( )
673691 . rposition ( |& len| len > 1 )
674- . unwrap_or ( self . dimension . ndim ( ) - 1 ) ,
692+ . unwrap_or ( self . dimension . ndim ( ) - 1 )
693+ } else {
675694 /* corder or default */
676- _ => self
695+ self
677696 . dimension
678697 . slice ( )
679698 . iter ( )
680699 . position ( |& len| len > 1 )
681- . unwrap_or ( 0 ) ,
700+ . unwrap_or ( 0 )
682701 } ;
683702 Axis ( i)
684703 }
@@ -699,6 +718,7 @@ where
699718 self . apply_core_strided ( acc, function)
700719 }
701720 }
721+
702722 fn apply_core_contiguous < F , Acc > ( & mut self , mut acc : Acc , mut function : F ) -> FoldWhile < Acc >
703723 where
704724 F : FnMut ( Acc , P :: Item ) -> FoldWhile < Acc > ,
@@ -717,7 +737,7 @@ where
717737 FoldWhile :: Continue ( acc)
718738 }
719739
720- fn apply_core_strided < F , Acc > ( & mut self , mut acc : Acc , mut function : F ) -> FoldWhile < Acc >
740+ fn apply_core_strided < F , Acc > ( & mut self , acc : Acc , function : F ) -> FoldWhile < Acc >
721741 where
722742 F : FnMut ( Acc , P :: Item ) -> FoldWhile < Acc > ,
723743 P : ZippableTuple < Dim = D > ,
@@ -726,13 +746,27 @@ where
726746 if n == 0 {
727747 panic ! ( "Unreachable: ndim == 0 is contiguous" )
728748 }
749+ if n == 1 || self . layout_tendency >= 0 {
750+ self . apply_core_strided_c ( acc, function)
751+ } else {
752+ self . apply_core_strided_f ( acc, function)
753+ }
754+ }
755+
756+ // Non-contiguous but preference for C - unroll over Axis(ndim - 1)
757+ fn apply_core_strided_c < F , Acc > ( & mut self , mut acc : Acc , mut function : F ) -> FoldWhile < Acc >
758+ where
759+ F : FnMut ( Acc , P :: Item ) -> FoldWhile < Acc > ,
760+ P : ZippableTuple < Dim = D > ,
761+ {
762+ let n = self . dimension . ndim ( ) ;
729763 let unroll_axis = n - 1 ;
730764 let inner_len = self . dimension [ unroll_axis] ;
731765 self . dimension [ unroll_axis] = 1 ;
732766 let mut index_ = self . dimension . first_index ( ) ;
733767 let inner_strides = self . parts . stride_of ( unroll_axis) ;
768+ // Loop unrolled over closest axis
734769 while let Some ( index) = index_ {
735- // Let's “unroll” the loop over the innermost axis
736770 unsafe {
737771 let ptr = self . parts . uget_ptr ( & index) ;
738772 for i in 0 ..inner_len {
@@ -747,9 +781,40 @@ where
747781 FoldWhile :: Continue ( acc)
748782 }
749783
784+ // Non-contiguous but preference for F - unroll over Axis(0)
785+ fn apply_core_strided_f < F , Acc > ( & mut self , mut acc : Acc , mut function : F ) -> FoldWhile < Acc >
786+ where
787+ F : FnMut ( Acc , P :: Item ) -> FoldWhile < Acc > ,
788+ P : ZippableTuple < Dim = D > ,
789+ {
790+ let unroll_axis = 0 ;
791+ let inner_len = self . dimension [ unroll_axis] ;
792+ self . dimension [ unroll_axis] = 1 ;
793+ let index_ = self . dimension . first_index ( ) ;
794+ let inner_strides = self . parts . stride_of ( unroll_axis) ;
795+ // Loop unrolled over closest axis
796+ if let Some ( mut index) = index_ {
797+ loop {
798+ unsafe {
799+ let ptr = self . parts . uget_ptr ( & index) ;
800+ for i in 0 ..inner_len {
801+ let p = ptr. stride_offset ( inner_strides, i) ;
802+ acc = fold_while ! ( function( acc, self . parts. as_ref( p) ) ) ;
803+ }
804+ }
805+
806+ if !self . dimension . next_for_f ( & mut index) {
807+ break ;
808+ }
809+ }
810+ }
811+ self . dimension [ unroll_axis] = inner_len;
812+ FoldWhile :: Continue ( acc)
813+ }
814+
750815 pub ( crate ) fn uninitalized_for_current_layout < T > ( & self ) -> Array < MaybeUninit < T > , D >
751816 {
752- let is_f = ! self . layout . is ( CORDER ) && self . layout . is ( FORDER ) ;
817+ let is_f = self . prefer_f ( ) ;
753818 Array :: maybe_uninit ( self . dimension . clone ( ) . set_f ( is_f) )
754819 }
755820}
@@ -995,8 +1060,9 @@ macro_rules! map_impl {
9951060 let ( $( $p, ) * ) = self . parts;
9961061 Zip {
9971062 parts: ( $( $p, ) * part, ) ,
998- layout: self . layout. and ( part_layout) ,
1063+ layout: self . layout. intersect ( part_layout) ,
9991064 dimension: self . dimension,
1065+ layout_tendency: self . layout_tendency + part_layout. tendency( ) ,
10001066 }
10011067 }
10021068
@@ -1052,11 +1118,13 @@ macro_rules! map_impl {
10521118 dimension: d1,
10531119 layout: self . layout,
10541120 parts: p1,
1121+ layout_tendency: self . layout_tendency,
10551122 } ,
10561123 Zip {
10571124 dimension: d2,
10581125 layout: self . layout,
10591126 parts: p2,
1127+ layout_tendency: self . layout_tendency,
10601128 } )
10611129 }
10621130 }
0 commit comments