@@ -13,7 +13,7 @@ pub struct Shape<D> {
1313}
1414
1515#[ derive( Copy , Clone , Debug ) ]
16- pub ( crate ) enum Contiguous { }
16+ pub enum Contiguous { }
1717
1818impl < D > Shape < D > {
1919 pub ( crate ) fn is_c ( & self ) -> bool {
@@ -31,7 +31,7 @@ pub struct StrideShape<D> {
3131
3232/// Stride description
3333#[ derive( Copy , Clone , Debug ) ]
34- pub ( crate ) enum Strides < D > {
34+ pub enum Strides < D > {
3535 /// Row-major ("C"-order)
3636 C ,
3737 /// Column-major ("F"-order)
@@ -168,3 +168,63 @@ where
168168 self . dim . size ( )
169169 }
170170}
171+
172+
173+ use crate :: order:: Order ;
174+
175+ pub trait ShapeArg {
176+ type Dim : Dimension ;
177+ type StrideType ;
178+
179+ fn into_shape_and_order ( self , default : Order ) -> ( Self :: Dim , Order ) ;
180+ fn into_shape_and_strides ( self , default : Order ) -> ( Self :: Dim , Strides < Self :: StrideType > ) ;
181+ }
182+
183+ impl < T > ShapeArg for T where T : IntoDimension {
184+ type Dim = T :: Dim ;
185+ type StrideType = Contiguous ;
186+
187+ fn into_shape_and_order ( self , default : Order ) -> ( Self :: Dim , Order ) {
188+ ( self . into_dimension ( ) , default)
189+ }
190+
191+ fn into_shape_and_strides ( self , _default : Order ) -> ( Self :: Dim , Strides < Contiguous > ) {
192+ unimplemented ! ( )
193+ }
194+ }
195+
196+ impl < T > ShapeArg for ( T , Order ) where T : IntoDimension {
197+ type Dim = T :: Dim ;
198+ type StrideType = Contiguous ;
199+
200+ fn into_shape_and_order ( self , _default : Order ) -> ( Self :: Dim , Order ) {
201+ ( self . 0 . into_dimension ( ) , self . 1 )
202+ }
203+
204+ fn into_shape_and_strides ( self , _default : Order ) -> ( Self :: Dim , Strides < Contiguous > ) {
205+ unimplemented ! ( )
206+ }
207+ }
208+
209+ /// Custom strides
210+ #[ derive( Copy , Clone , Debug ) ]
211+ pub struct CustomStrides < D > { strides : D }
212+
213+ // newtype constructor without public field
214+ #[ allow( non_snake_case) ]
215+ pub fn CustomStrides < T > ( strides : T ) -> CustomStrides < T > {
216+ CustomStrides { strides }
217+ }
218+
219+ impl < T > ShapeArg for ( T , CustomStrides < T > ) where T : IntoDimension {
220+ type Dim = T :: Dim ;
221+ type StrideType = T :: Dim ;
222+
223+ fn into_shape_and_order ( self , _default : Order ) -> ( Self :: Dim , Order ) {
224+ ( self . 0 . into_dimension ( ) , _default)
225+ }
226+
227+ fn into_shape_and_strides ( self , _default : Order ) -> ( Self :: Dim , Strides < T :: Dim > ) {
228+ ( self . 0 . into_dimension ( ) , Strides :: Custom ( self . 1 . strides . into_dimension ( ) ) )
229+ }
230+ }
0 commit comments