Skip to content

Commit 1355018

Browse files
committed
add doc for variadic_sample
1 parent e30e2b0 commit 1355018

File tree

2 files changed

+15
-8
lines changed

2 files changed

+15
-8
lines changed

doc/source/notes/variadic.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,9 @@ Naturally, the prediction over nodes also forms a variadic tensor with ``num_nod
111111
:func:`variadic_arange <torchdrug.layers.functional.variadic_arange>`,
112112
:func:`variadic_sort <torchdrug.layers.functional.variadic_sort>`,
113113
:func:`variadic_topk <torchdrug.layers.functional.variadic_topk>`,
114+
:func:`variadic_randperm <torchdrug.layers.functional.variadic_randperm>`,
115+
:func:`variadic_sample <torchdrug.layers.functional.variadic_sample>`,
116+
:func:`variadic_softmax <torchdrug.layers.functional.variadic_softmax>`,
114117
:func:`variadic_log_softmax <torchdrug.layers.functional.variadic_log_softmax>`,
115118
:func:`variadic_cross_entropy <torchdrug.layers.functional.variadic_cross_entropy>`,
116119
:func:`variadic_accuracy <torchdrug.metrics.variadic_accuracy>`

torchdrug/layers/functional/functional.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -185,9 +185,6 @@ def variadic_sum(input, size):
185185
Parameters:
186186
input (Tensor): input of shape :math:`(B, ...)`
187187
size (LongTensor): size of sets of shape :math:`(N,)`
188-
189-
Returns
190-
Tensor: sum
191188
"""
192189
index2sample = _size_to_index(size)
193190
index2sample = index2sample.view([-1] + [1] * (input.ndim - 1))
@@ -206,9 +203,6 @@ def variadic_mean(input, size):
206203
Parameters:
207204
input (Tensor): input of shape :math:`(B, ...)`
208205
size (LongTensor): size of sets of shape :math:`(N,)`
209-
210-
Returns
211-
Tensor: mean
212206
"""
213207
index2sample = _size_to_index(size)
214208
index2sample = index2sample.view([-1] + [1] * (input.ndim - 1))
@@ -420,8 +414,18 @@ def variadic_randperm(size):
420414
return perm
421415

422416

423-
def variadic_sample(input, size, k):
424-
rand = torch.rand(len(size), k, device=size.device)
417+
def variadic_sample(input, size, num_sample):
418+
"""
419+
Draw samples with replacement from sets with variadic sizes.
420+
421+
Suppose there are :math:`N` sets, and the sizes of all sets are summed to :math:`B`.
422+
423+
Parameters:
424+
input (Tensor): input of shape :math:`(B, ...)`
425+
size (LongTensor): size of sets of shape :math:`(N,)`
426+
num_sample (int): number of samples to draw from each set
427+
"""
428+
rand = torch.rand(len(size), num_sample, device=size.device)
425429
index = (rand * size.unsqueeze(-1)).long()
426430
index = index + (size.cumsum(0) - size).unsqueeze(-1)
427431
sample = input[index]

0 commit comments

Comments
 (0)