88use crate :: dimension:: slices_intersect;
99use crate :: error:: { ErrorKind , ShapeError } ;
1010use crate :: { ArrayViewMut , DimAdd , Dimension , Ix0 , Ix1 , Ix2 , Ix3 , Ix4 , Ix5 , Ix6 , IxDyn } ;
11+ use alloc:: vec:: Vec ;
12+ use std:: convert:: TryFrom ;
1113use std:: fmt;
1214use std:: marker:: PhantomData ;
1315use std:: ops:: { Deref , Range , RangeFrom , RangeFull , RangeInclusive , RangeTo , RangeToInclusive } ;
@@ -402,6 +404,24 @@ where
402404 }
403405}
404406
407+ fn check_dims_for_sliceinfo < Din , Dout > ( indices : & [ AxisSliceInfo ] ) -> Result < ( ) , ShapeError >
408+ where
409+ Din : Dimension ,
410+ Dout : Dimension ,
411+ {
412+ if let Some ( in_ndim) = Din :: NDIM {
413+ if in_ndim != indices. in_ndim ( ) {
414+ return Err ( ShapeError :: from_kind ( ErrorKind :: IncompatibleShape ) ) ;
415+ }
416+ }
417+ if let Some ( out_ndim) = Dout :: NDIM {
418+ if out_ndim != indices. out_ndim ( ) {
419+ return Err ( ShapeError :: from_kind ( ErrorKind :: IncompatibleShape ) ) ;
420+ }
421+ }
422+ Ok ( ( ) )
423+ }
424+
405425impl < T , Din , Dout > SliceInfo < T , Din , Dout >
406426where
407427 T : AsRef < [ AxisSliceInfo ] > ,
@@ -424,12 +444,8 @@ where
424444 out_dim : PhantomData < Dout > ,
425445 ) -> SliceInfo < T , Din , Dout > {
426446 if cfg ! ( debug_assertions) {
427- if let Some ( in_ndim) = Din :: NDIM {
428- assert_eq ! ( in_ndim, indices. as_ref( ) . in_ndim( ) ) ;
429- }
430- if let Some ( out_ndim) = Dout :: NDIM {
431- assert_eq ! ( out_ndim, indices. as_ref( ) . out_ndim( ) ) ;
432- }
447+ check_dims_for_sliceinfo :: < Din , Dout > ( indices. as_ref ( ) )
448+ . expect ( "`Din` and `Dout` must be consistent with `indices`." ) ;
433449 }
434450 SliceInfo {
435451 in_dim,
@@ -449,21 +465,14 @@ where
449465 ///
450466 /// Errors if `Din` or `Dout` is not consistent with `indices`.
451467 ///
468+ /// For common types, a safe alternative is to use `TryFrom` instead.
469+ ///
452470 /// # Safety
453471 ///
454472 /// The caller must ensure `indices.as_ref()` always returns the same value
455473 /// when called multiple times.
456474 pub unsafe fn new ( indices : T ) -> Result < SliceInfo < T , Din , Dout > , ShapeError > {
457- if let Some ( in_ndim) = Din :: NDIM {
458- if in_ndim != indices. as_ref ( ) . in_ndim ( ) {
459- return Err ( ShapeError :: from_kind ( ErrorKind :: IncompatibleShape ) ) ;
460- }
461- }
462- if let Some ( out_ndim) = Dout :: NDIM {
463- if out_ndim != indices. as_ref ( ) . out_ndim ( ) {
464- return Err ( ShapeError :: from_kind ( ErrorKind :: IncompatibleShape ) ) ;
465- }
466- }
475+ check_dims_for_sliceinfo :: < Din , Dout > ( indices. as_ref ( ) ) ?;
467476 Ok ( SliceInfo {
468477 in_dim : PhantomData ,
469478 out_dim : PhantomData ,
@@ -508,6 +517,79 @@ where
508517 }
509518}
510519
520+ impl < ' a , Din , Dout > TryFrom < & ' a [ AxisSliceInfo ] > for & ' a SliceInfo < [ AxisSliceInfo ] , Din , Dout >
521+ where
522+ Din : Dimension ,
523+ Dout : Dimension ,
524+ {
525+ type Error = ShapeError ;
526+
527+ fn try_from (
528+ indices : & ' a [ AxisSliceInfo ] ,
529+ ) -> Result < & ' a SliceInfo < [ AxisSliceInfo ] , Din , Dout > , ShapeError > {
530+ check_dims_for_sliceinfo :: < Din , Dout > ( indices) ?;
531+ unsafe {
532+ // This is okay because we've already checked the correctness of
533+ // `Din` and `Dout`, and the only non-zero-sized member of
534+ // `SliceInfo` is `indices`, so `&SliceInfo<[AxisSliceInfo], Din,
535+ // Dout>` should have the same bitwise representation as
536+ // `&[AxisSliceInfo]`.
537+ Ok ( & * ( indices as * const [ AxisSliceInfo ]
538+ as * const SliceInfo < [ AxisSliceInfo ] , Din , Dout > ) )
539+ }
540+ }
541+ }
542+
543+ impl < Din , Dout > TryFrom < Vec < AxisSliceInfo > > for SliceInfo < Vec < AxisSliceInfo > , Din , Dout >
544+ where
545+ Din : Dimension ,
546+ Dout : Dimension ,
547+ {
548+ type Error = ShapeError ;
549+
550+ fn try_from (
551+ indices : Vec < AxisSliceInfo > ,
552+ ) -> Result < SliceInfo < Vec < AxisSliceInfo > , Din , Dout > , ShapeError > {
553+ unsafe {
554+ // This is okay because `Vec` always returns the same value for
555+ // `.as_ref()`.
556+ Self :: new ( indices)
557+ }
558+ }
559+ }
560+
561+ macro_rules! impl_tryfrom_array_for_sliceinfo {
562+ ( $len: expr) => {
563+ impl <Din , Dout > TryFrom <[ AxisSliceInfo ; $len] >
564+ for SliceInfo <[ AxisSliceInfo ; $len] , Din , Dout >
565+ where
566+ Din : Dimension ,
567+ Dout : Dimension ,
568+ {
569+ type Error = ShapeError ;
570+
571+ fn try_from(
572+ indices: [ AxisSliceInfo ; $len] ,
573+ ) -> Result <SliceInfo <[ AxisSliceInfo ; $len] , Din , Dout >, ShapeError > {
574+ unsafe {
575+ // This is okay because `[AxisSliceInfo; N]` always returns
576+ // the same value for `.as_ref()`.
577+ Self :: new( indices)
578+ }
579+ }
580+ }
581+ } ;
582+ }
583+ impl_tryfrom_array_for_sliceinfo ! ( 0 ) ;
584+ impl_tryfrom_array_for_sliceinfo ! ( 1 ) ;
585+ impl_tryfrom_array_for_sliceinfo ! ( 2 ) ;
586+ impl_tryfrom_array_for_sliceinfo ! ( 3 ) ;
587+ impl_tryfrom_array_for_sliceinfo ! ( 4 ) ;
588+ impl_tryfrom_array_for_sliceinfo ! ( 5 ) ;
589+ impl_tryfrom_array_for_sliceinfo ! ( 6 ) ;
590+ impl_tryfrom_array_for_sliceinfo ! ( 7 ) ;
591+ impl_tryfrom_array_for_sliceinfo ! ( 8 ) ;
592+
511593impl < T , Din , Dout > AsRef < [ AxisSliceInfo ] > for SliceInfo < T , Din , Dout >
512594where
513595 T : AsRef < [ AxisSliceInfo ] > ,
0 commit comments