@@ -185,9 +185,6 @@ def variadic_sum(input, size):
185185 Parameters:
186186 input (Tensor): input of shape :math:`(B, ...)`
187187 size (LongTensor): size of sets of shape :math:`(N,)`
188-
189- Returns
190- Tensor: sum
191188 """
192189 index2sample = _size_to_index (size )
193190 index2sample = index2sample .view ([- 1 ] + [1 ] * (input .ndim - 1 ))
@@ -206,9 +203,6 @@ def variadic_mean(input, size):
206203 Parameters:
207204 input (Tensor): input of shape :math:`(B, ...)`
208205 size (LongTensor): size of sets of shape :math:`(N,)`
209-
210- Returns
211- Tensor: mean
212206 """
213207 index2sample = _size_to_index (size )
214208 index2sample = index2sample .view ([- 1 ] + [1 ] * (input .ndim - 1 ))
@@ -420,8 +414,18 @@ def variadic_randperm(size):
420414 return perm
421415
422416
423- def variadic_sample (input , size , k ):
424- rand = torch .rand (len (size ), k , device = size .device )
417+ def variadic_sample (input , size , num_sample ):
418+ """
419+ Draw samples with replacement from sets with variadic sizes.
420+
421+ Suppose there are :math:`N` sets, and the sizes of all sets are summed to :math:`B`.
422+
423+ Parameters:
424+ input (Tensor): input of shape :math:`(B, ...)`
425+ size (LongTensor): size of sets of shape :math:`(N,)`
426+ num_sample (int): number of samples to draw from each set
427+ """
428+ rand = torch .rand (len (size ), num_sample , device = size .device )
425429 index = (rand * size .unsqueeze (- 1 )).long ()
426430 index = index + (size .cumsum (0 ) - size ).unsqueeze (- 1 )
427431 sample = input [index ]
0 commit comments