@@ -9,44 +9,33 @@ public static partial class torchvision
99 {
1010 internal class Normalize : ITransform , IDisposable
1111 {
12- internal Normalize ( double [ ] means , double [ ] stdevs , ScalarType dtype = ScalarType . Float32 , torch . Device ? device = null )
12+ internal Normalize ( double [ ] means , double [ ] stdevs , bool inplace = false )
1313 {
1414 if ( means is null ) throw new ArgumentNullException ( nameof ( means ) ) ;
1515 if ( stdevs is null ) throw new ArgumentNullException ( nameof ( stdevs ) ) ;
1616 if ( means . Length != stdevs . Length )
1717 throw new ArgumentException ( $ "{ nameof ( means ) } and { nameof ( stdevs ) } must be the same length in call to Normalize") ;
1818 if ( means . Length != 1 && means . Length != 3 )
1919 throw new ArgumentException ( $ "Since they correspond to the number of channels in an image, { nameof ( means ) } and { nameof ( stdevs ) } must both be either 1 or 3 long") ;
20+ this . means = means ;
21+ this . stdevs = stdevs ;
22+ this . inplace = inplace ;
2023
21- this . means = means . ToTensor ( new long [ ] { 1 , means . Length , 1 , 1 } ) ; // Assumes NxCxHxW
22- this . stdevs = stdevs . ToTensor ( new long [ ] { 1 , stdevs . Length , 1 , 1 } ) ; // Assumes NxCxHxW
23-
24- if ( dtype != ScalarType . Float64 ) {
25- this . means = this . means . to_type ( dtype ) ;
26- this . stdevs = this . stdevs . to_type ( dtype ) ;
27- }
28-
29- if ( device != null && device . type != DeviceType . CPU ) {
30- this . means = this . means . to ( device ) ;
31- this . stdevs = this . stdevs . to ( device ) ;
32- }
3324 }
3425
3526 public Tensor call ( Tensor input )
3627 {
37- if ( means . size ( 1 ) != input . size ( 1 ) ) throw new ArgumentException ( "The number of channels is not equal to the number of means and standard deviations" ) ;
38- return ( input - means ) / stdevs ;
28+ return transforms . functional . normalize ( input , means , stdevs , inplace ) ;
3929 }
4030
41- private Tensor means ;
42- private Tensor stdevs ;
31+ private readonly double [ ] means ;
32+ private readonly double [ ] stdevs ;
33+ private readonly bool inplace ;
4334 bool disposedValue ;
4435
4536 protected virtual void Dispose ( bool disposing )
4637 {
4738 if ( ! disposedValue ) {
48- means ? . Dispose ( ) ;
49- stdevs ? . Dispose ( ) ;
5039 disposedValue = true ;
5140 }
5241 }
@@ -72,12 +61,11 @@ public static partial class transforms
7261 /// </summary>
7362 /// <param name="means">Sequence of means for each channel.</param>
7463 /// <param name="stdevs">Sequence of standard deviations for each channel.</param>
75- /// <param name="dtype">Bool to make this operation inplace.</param>
76- /// <param name="device">The device to place the output tensor on.</param>
64+ /// <param name="inplace">Bool to make this operation inplace.</param>
7765 /// <returns></returns>
78- static public ITransform Normalize ( double [ ] means , double [ ] stdevs , ScalarType dtype = ScalarType . Float32 , torch . Device ? device = null )
66+ static public ITransform Normalize ( double [ ] means , double [ ] stdevs , bool inplace = false )
7967 {
80- return new Normalize ( means , stdevs , dtype , device ) ;
68+ return new Normalize ( means , stdevs , inplace ) ;
8169 }
8270 }
8371 }
0 commit comments