@@ -4,7 +4,9 @@ mod activation_mode;
44pub use activation_descriptor:: * ;
55pub use activation_mode:: * ;
66
7- use crate :: { private, sys, CudnnContext , CudnnError , DataType , IntoResult , TensorDescriptor } ;
7+ use crate :: {
8+ private, sys, CudnnContext , CudnnError , DataType , IntoResult , ScalingDataType , TensorDescriptor ,
9+ } ;
810use cust:: memory:: GpuBuffer ;
911use std:: mem:: MaybeUninit ;
1012
@@ -49,11 +51,11 @@ impl CudnnContext {
4951 ///
5052 /// let desc = ActivationDescriptor::new(mode, nan_opt, coefficient)?;
5153 ///
52- /// let alpha: f32 = 1.0;
54+ /// let alpha = 1.0;
5355 /// let x_desc = TensorDescriptor::<i8>::new_strides(&[1, 1, 1, 5], &[5, 5, 5, 1])?;
5456 /// let x = DeviceBuffer::<i8>::from_slice(&[10, 10, 10, 10, 10])?;
5557 ///
56- /// let beta: f32 = 0.0;
58+ /// let beta = 0.0;
5759 /// let y_desc = TensorDescriptor::<i8>::new_strides(&[1, 1, 1, 5], &[5, 5, 5, 1])?;
5860 /// let mut y = DeviceBuffer::<i8>::from_slice(&[0, 0, 0, 0, 0])?;
5961 ///
@@ -76,7 +78,7 @@ impl CudnnContext {
7678 y : & mut impl GpuBuffer < T > ,
7779 ) -> Result < ( ) , CudnnError >
7880 where
79- CompT : SupportedActFwd < T > ,
81+ CompT : ScalingDataType < T > ,
8082 T : DataType ,
8183 {
8284 let alpha_ptr = & alpha as * const CompT as * const _ ;
@@ -179,27 +181,6 @@ impl CudnnContext {
179181 }
180182}
181183
182- /// Supported data type configurations for the activation forward operation.
183- pub trait SupportedActFwd < T > : DataType + private:: Sealed
184- where
185- T : DataType ,
186- {
187- }
188-
189- impl SupportedActFwd < i8 > for f32 { }
190- impl SupportedActFwd < u8 > for f32 { }
191- impl SupportedActFwd < i32 > for f32 { }
192- impl SupportedActFwd < i64 > for f32 { }
193- impl SupportedActFwd < f32 > for f32 { }
194- impl SupportedActFwd < f64 > for f32 { }
195-
196- impl SupportedActFwd < i8 > for f64 { }
197- impl SupportedActFwd < u8 > for f64 { }
198- impl SupportedActFwd < i32 > for f64 { }
199- impl SupportedActFwd < i64 > for f64 { }
200- impl SupportedActFwd < f32 > for f64 { }
201- impl SupportedActFwd < f64 > for f64 { }
202-
203184/// Supported type configurations for the activation backward operation.
204185pub trait SupportedActBwd < T > : DataType + private:: Sealed
205186where
0 commit comments