Skip to content

Commit 0a58893

Browse files
authored
Testing functional module (#50)
* Update functional tests * Update docs styles * Update level and circular examples * Add type bounds test for bundle and permute * Add encoding tests * Add utility tests * Use utility functions in embeddings
1 parent 0e99946 commit 0a58893

File tree

8 files changed

+1086

-272

lines changed

docs/_static/css/custom.css

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -233,10 +233,13 @@ html.writer-html5 .rst-content table.docutils th {
233233
}
234234

235235
html,
236-
body {
237-
background: #ffffff;
236+
body,
237+
.wy-body-for-nav,
238+
.wy-nav-content {
239+
background: #ffffff !important;
238240
}
239241

242+
240243
.wy-nav-content-wrap {
241244
background: none;
242245
}

torchhd/embeddings.py

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -151,13 +151,9 @@ def reset_parameters(self):
151151
self._fill_padding_idx_with_zero()
152152

153153
def forward(self, input: torch.Tensor) -> torch.Tensor:
154-
# tranform the floating point input to an index
155-
# make first variable a copy of the input, then we can reuse the buffer.
156-
# normalized between 0 and 1
157-
normalized = (input - self.low_value) / (self.high_value - self.low_value)
158-
159-
indices = normalized.mul_(self.num_embeddings).floor_()
160-
indices = indices.clamp_(0, self.num_embeddings - 1).long()
154+
indices = functional.value_to_index(
155+
input, self.low_value, self.high_value, self.num_embeddings
156+
).clamp(0, self.num_embeddings - 1)
161157

162158
return super(Level, self).forward(indices)
163159

@@ -219,14 +215,9 @@ def reset_parameters(self):
219215
self._fill_padding_idx_with_zero()
220216

221217
def forward(self, input: torch.Tensor) -> torch.Tensor:
222-
# tranform the floating point input to an index
223-
# make first variable a copy of the input, then we can reuse the buffer.
224-
# normalized between 0 and 1
225-
normalized = (input - self.low_value) / (self.high_value - self.low_value)
226-
normalized.remainder_(1.0)
227-
228-
indices = normalized.mul_(self.num_embeddings).floor_()
229-
indices = indices.clamp_(0, self.num_embeddings - 1).long()
218+
indices = functional.value_to_index(
219+
input, self.low_value, self.high_value, self.num_embeddings
220+
).remainder(self.num_embeddings - 1)
230221

231222
return super(Circular, self).forward(indices)
232223

0 commit comments

Comments (0)