@@ -151,13 +151,9 @@ def reset_parameters(self):
151151 self ._fill_padding_idx_with_zero ()
152152
def forward(self, input: torch.Tensor) -> torch.Tensor:
    """Map each floating-point value in ``input`` to its level embedding.

    Values are converted to integer bucket indices over the range
    ``[low_value, high_value]`` and then clamped into
    ``[0, num_embeddings - 1]`` so out-of-range inputs saturate at the
    first or last level.
    """
    # NOTE(review): assumes functional.value_to_index maps [low, high]
    # linearly onto bucket indices 0..num_embeddings — confirm in functional.
    bucket = functional.value_to_index(
        input, self.low_value, self.high_value, self.num_embeddings
    )
    # Saturate: inputs below low_value hit level 0, above high_value hit
    # the last level.
    bucket = bucket.clamp(0, self.num_embeddings - 1)

    return super(Level, self).forward(bucket)
163159
@@ -219,14 +215,9 @@ def reset_parameters(self):
219215 self ._fill_padding_idx_with_zero ()
220216
def forward(self, input: torch.Tensor) -> torch.Tensor:
    """Map each floating-point value in ``input`` to its circular embedding.

    Values are converted to integer bucket indices over the range
    ``[low_value, high_value]`` and wrapped modulo ``num_embeddings`` so the
    mapping is periodic: ``input == high_value`` wraps back to index 0, and
    inputs outside the range wrap around instead of saturating.
    """
    # NOTE(review): assumes functional.value_to_index maps [low, high]
    # linearly onto bucket indices 0..num_embeddings — confirm in functional.
    indices = functional.value_to_index(
        input, self.low_value, self.high_value, self.num_embeddings
    )
    # Wrap modulo num_embeddings, NOT (num_embeddings - 1). With the
    # off-by-one modulus the last level becomes unreachable (index n-1
    # collapses onto 0) and input == high_value maps to index 1, whereas the
    # pre-refactor code (remainder_(1.0) on the normalized value before
    # scaling) wrapped with period num_embeddings. torch.remainder with a
    # positive divisor also returns non-negative indices for inputs below
    # low_value, matching the old wrapping behavior.
    indices = indices.remainder(self.num_embeddings)

    return super(Circular, self).forward(indices)
232223
0 commit comments