11use crate :: dimension:: IntoDimension ;
22use crate :: Dimension ;
3- use crate :: { Shape , StrideShape } ;
3+
4+ /// A contiguous array shape of n dimensions.
5+ ///
6+ /// Either c- or f- memory ordered (*c* a.k.a *row major* is the default).
7+ #[ derive( Copy , Clone , Debug ) ]
8+ pub struct Shape < D > {
9+ /// Shape (axis lengths)
10+ pub ( crate ) dim : D ,
11+ /// Strides can only be C or F here
12+ pub ( crate ) strides : Strides < Contiguous > ,
13+ }
14+
15+ #[ derive( Copy , Clone , Debug ) ]
16+ pub ( crate ) enum Contiguous { }
17+
18+ impl < D > Shape < D > {
19+ pub ( crate ) fn is_c ( & self ) -> bool {
20+ matches ! ( self . strides, Strides :: C )
21+ }
22+ }
23+
24+
25+ /// An array shape of n dimensions in c-order, f-order or custom strides.
26+ #[ derive( Copy , Clone , Debug ) ]
27+ pub struct StrideShape < D > {
28+ pub ( crate ) dim : D ,
29+ pub ( crate ) strides : Strides < D > ,
30+ }
31+
32+ /// Stride description
33+ #[ derive( Copy , Clone , Debug ) ]
34+ pub ( crate ) enum Strides < D > {
35+ /// Row-major ("C"-order)
36+ C ,
37+ /// Column-major ("F"-order)
38+ F ,
39+ /// Custom strides
40+ Custom ( D )
41+ }
42+
43+ impl < D > Strides < D > {
44+ /// Return strides for `dim` (computed from dimension if c/f, else return the custom stride)
45+ pub ( crate ) fn strides_for_dim ( self , dim : & D ) -> D
46+ where D : Dimension
47+ {
48+ match self {
49+ Strides :: C => dim. default_strides ( ) ,
50+ Strides :: F => dim. fortran_strides ( ) ,
51+ Strides :: Custom ( c) => {
52+ debug_assert_eq ! ( c. ndim( ) , dim. ndim( ) ,
53+ "Custom strides given with {} dimensions, expected {}" ,
54+ c. ndim( ) , dim. ndim( ) ) ;
55+ c
56+ }
57+ }
58+ }
59+
60+ pub ( crate ) fn is_custom ( & self ) -> bool {
61+ matches ! ( * self , Strides :: Custom ( _) )
62+ }
63+ }
464
565/// A trait for `Shape` and `D where D: Dimension` that allows
666/// customizing the memory layout (strides) of an array shape.
@@ -34,36 +94,18 @@ where
3494{
3595 fn from ( value : T ) -> Self {
3696 let shape = value. into_shape ( ) ;
37- let d = shape. dim ;
38- let st = if shape. is_c {
39- d. default_strides ( )
97+ let st = if shape. is_c ( ) {
98+ Strides :: C
4099 } else {
41- d . fortran_strides ( )
100+ Strides :: F
42101 } ;
43102 StrideShape {
44103 strides : st,
45- dim : d,
46- custom : false ,
104+ dim : shape. dim ,
47105 }
48106 }
49107}
50108
51- /*
52- impl<D> From<Shape<D>> for StrideShape<D>
53- where D: Dimension
54- {
55- fn from(shape: Shape<D>) -> Self {
56- let d = shape.dim;
57- let st = if shape.is_c { d.default_strides() } else { d.fortran_strides() };
58- StrideShape {
59- strides: st,
60- dim: d,
61- custom: false,
62- }
63- }
64- }
65- */
66-
67109impl < T > ShapeBuilder for T
68110where
69111 T : IntoDimension ,
73115 fn into_shape ( self ) -> Shape < Self :: Dim > {
74116 Shape {
75117 dim : self . into_dimension ( ) ,
76- is_c : true ,
118+ strides : Strides :: C ,
77119 }
78120 }
79121 fn f ( self ) -> Shape < Self :: Dim > {
@@ -93,21 +135,24 @@ where
93135{
94136 type Dim = D ;
95137 type Strides = D ;
138+
96139 fn into_shape ( self ) -> Shape < D > {
97140 self
98141 }
142+
99143 fn f ( self ) -> Self {
100144 self . set_f ( true )
101145 }
146+
102147 fn set_f ( mut self , is_f : bool ) -> Self {
103- self . is_c = !is_f;
148+ self . strides = if !is_f { Strides :: C } else { Strides :: F } ;
104149 self
105150 }
151+
106152 fn strides ( self , st : D ) -> StrideShape < D > {
107153 StrideShape {
108154 dim : self . dim ,
109- strides : st,
110- custom : true ,
155+ strides : Strides :: Custom ( st) ,
111156 }
112157 }
113158}
0 commit comments