@@ -13,7 +13,7 @@ pub struct Shape<D> {
1313}
1414
1515#[ derive( Copy , Clone , Debug ) ]
16- pub ( crate ) enum Contiguous { }
16+ pub enum Contiguous { }
1717
1818impl < D > Shape < D > {
1919 pub ( crate ) fn is_c ( & self ) -> bool {
4444
4545/// Stride description
4646#[ derive( Copy , Clone , Debug ) ]
47- pub ( crate ) enum Strides < D > {
47+ pub enum Strides < D > {
4848 /// Row-major ("C"-order)
4949 C ,
5050 /// Column-major ("F"-order)
@@ -184,3 +184,63 @@ where
184184 self . dim . size ( )
185185 }
186186}
187+
188+
189+ use crate :: order:: Order ;
190+
191+ pub trait ShapeArg {
192+ type Dim : Dimension ;
193+ type StrideType ;
194+
195+ fn into_shape_and_order ( self , default : Order ) -> ( Self :: Dim , Order ) ;
196+ fn into_shape_and_strides ( self , default : Order ) -> ( Self :: Dim , Strides < Self :: StrideType > ) ;
197+ }
198+
199+ impl < T > ShapeArg for T where T : IntoDimension {
200+ type Dim = T :: Dim ;
201+ type StrideType = Contiguous ;
202+
203+ fn into_shape_and_order ( self , default : Order ) -> ( Self :: Dim , Order ) {
204+ ( self . into_dimension ( ) , default)
205+ }
206+
207+ fn into_shape_and_strides ( self , _default : Order ) -> ( Self :: Dim , Strides < Contiguous > ) {
208+ unimplemented ! ( )
209+ }
210+ }
211+
212+ impl < T > ShapeArg for ( T , Order ) where T : IntoDimension {
213+ type Dim = T :: Dim ;
214+ type StrideType = Contiguous ;
215+
216+ fn into_shape_and_order ( self , _default : Order ) -> ( Self :: Dim , Order ) {
217+ ( self . 0 . into_dimension ( ) , self . 1 )
218+ }
219+
220+ fn into_shape_and_strides ( self , _default : Order ) -> ( Self :: Dim , Strides < Contiguous > ) {
221+ unimplemented ! ( )
222+ }
223+ }
224+
225+ /// Custom strides
226+ #[ derive( Copy , Clone , Debug ) ]
227+ pub struct CustomStrides < D > { strides : D }
228+
229+ // newtype constructor without public field
230+ #[ allow( non_snake_case) ]
231+ pub fn CustomStrides < T > ( strides : T ) -> CustomStrides < T > {
232+ CustomStrides { strides }
233+ }
234+
235+ impl < T > ShapeArg for ( T , CustomStrides < T > ) where T : IntoDimension {
236+ type Dim = T :: Dim ;
237+ type StrideType = T :: Dim ;
238+
239+ fn into_shape_and_order ( self , _default : Order ) -> ( Self :: Dim , Order ) {
240+ ( self . 0 . into_dimension ( ) , _default)
241+ }
242+
243+ fn into_shape_and_strides ( self , _default : Order ) -> ( Self :: Dim , Strides < T :: Dim > ) {
244+ ( self . 0 . into_dimension ( ) , Strides :: Custom ( self . 1 . strides . into_dimension ( ) ) )
245+ }
246+ }
0 commit comments