@@ -60,13 +60,41 @@ def shifted_softplus(input):
     return F.softplus(input) - F.softplus(torch.zeros(1, device=input.device))
 
 
+def multi_slice(starts, ends):
+    """
+    Compute the union of indexes in multiple slices.
+
+    Example::
+
+        >>> index = multi_slice(torch.tensor([0, 1, 4]), torch.tensor([2, 3, 6]))
+        >>> assert (index == torch.tensor([0, 1, 2, 4, 5])).all()
+
+    Parameters:
+        starts (LongTensor): start indexes of slices
+        ends (LongTensor): end indexes of slices
+    """
+    values = torch.cat([torch.ones_like(starts), -torch.ones_like(ends)])
+    slices = torch.cat([starts, ends])
+    slices, order = slices.sort()
+    values = values[order]
+    depth = values.cumsum(0)
+    valid = ((values == 1) & (depth == 1)) | ((values == -1) & (depth == 0))
+    slices = slices[valid]
+
+    starts, ends = slices.view(-1, 2).t()
+    size = ends - starts
+    indexes = variadic_arange(size)
+    indexes = indexes + starts.repeat_interleave(size)
+    return indexes
+
+
 def multi_slice_mask(starts, ends, length):
     """
     Compute the union of multiple slices into a binary mask.
 
     Example::
 
-        >>> mask = F.multi_slice_mask(torch.tensor([0, 1, 4]), torch.tensor([2, 3, 6]), 6)
+        >>> mask = multi_slice_mask(torch.tensor([0, 1, 4]), torch.tensor([2, 3, 6]), 6)
         >>> assert (mask == torch.tensor([1, 1, 1, 0, 1, 1])).all()
 
     Parameters:
@@ -75,10 +103,10 @@ def multi_slice_mask(starts, ends, length):
         length (int): length of mask
     """
     values = torch.cat([torch.ones_like(starts), -torch.ones_like(ends)])
-    indexes = torch.cat([starts, ends])
-    if indexes.numel():
-        assert indexes.min() >= 0 and indexes.max() <= length
-    mask = scatter_add(values, indexes, dim_size=length + 1)[:-1]
+    slices = torch.cat([starts, ends])
+    if slices.numel():
+        assert slices.min() >= 0 and slices.max() <= length
+    mask = scatter_add(values, slices, dim_size=length + 1)[:-1]
     mask = mask.cumsum(0).bool()
     return mask
 
@@ -108,12 +136,8 @@ def _size_to_index(size):
     Parameters:
         size (LongTensor): size of each sample
     """
-    cum_size = size.cumsum(0)
-    # special case 1: size[-1] = 0
-    index = cum_size[cum_size < cum_size[-1]]
-    # special case 2: size[i] = size[i+1] = 0
-    index2sample = scatter_add(torch.ones_like(index), index, dim_size=cum_size[-1])
-    index2sample = index2sample.cumsum(0)
+    range = torch.arange(len(size), device=size.device)
+    index2sample = range.repeat_interleave(size)
     return index2sample
 
 
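The rewritten `_size_to_index` leans on `torch.repeat_interleave`, which already handles the zero-size corner cases the old `scatter_add` version patched by hand (a trailing zero, or consecutive zeros). A quick check:

```python
import torch

size = torch.tensor([2, 0, 3, 0])  # includes an interior zero and a trailing zero
index2sample = torch.arange(len(size)).repeat_interleave(size)
print(index2sample)  # tensor([0, 0, 2, 2, 2]) -- empty samples contribute nothing
```

The same `arange` + `repeat_interleave` pattern replaces `_size_to_index` inside `variadic_arange` below, which is also what lets the explicit `device` arguments go away in favor of `size.device`.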
@@ -363,7 +387,7 @@ def variadic_sort(input, size, descending=False):
     return value, index
 
 
-def variadic_arange(size, device=None):
+def variadic_arange(size):
     """
     Return a 1-D tensor that contains integer intervals of variadic sizes.
     This is a variadic variant of ``torch.arange(stop).expand(batch_size, -1)``.
@@ -372,17 +396,15 @@ def variadic_arange(size, device=None):
 
     Parameters:
         size (LongTensor): size of intervals of shape :math:`(N,)`
-        device (torch.device, optional): device of the tensor
     """
-    index2sample = _size_to_index(size)
     starts = size.cumsum(0) - size
 
-    range = torch.arange(size.sum(), device=device)
-    range = range - starts[index2sample]
+    range = torch.arange(size.sum(), device=size.device)
+    range = range - starts.repeat_interleave(size)
     return range
 
 
-def variadic_randperm(size, device=None):
+def variadic_randperm(size):
     """
     Return random permutations for sets with variadic sizes.
     The ``i``-th permutation contains integers from 0 to ``size[i] - 1``.
@@ -393,11 +415,28 @@ def variadic_randperm(size, device=None):
         size (LongTensor): size of sets of shape :math:`(N,)`
-        device (torch.device, optional): device of the tensor
     """
-    rand = torch.rand(size.sum(), device=device)
+    rand = torch.rand(size.sum(), device=size.device)
     perm = variadic_sort(rand, size)[1]
     return perm
 
 
+def variadic_sample(input, size, k):
+    """
+    Draw samples with replacement from sets with variadic sizes.
+
+    Suppose there are :math:`N` sets, and the sizes of all sets are summed to :math:`B`.
+
+    Parameters:
+        input (Tensor): input of shape :math:`(B, ...)`
+        size (LongTensor): size of sets of shape :math:`(N,)`
+        k (int): number of samples per set
+    """
+    rand = torch.rand(len(size), k, device=size.device)
+    index = (rand * size.unsqueeze(-1)).long()
+    index = index + (size.cumsum(0) - size).unsqueeze(-1)
+    sample = input[index]
+    return sample
+
+
 def one_hot(index, size):
     """
     Expand indexes into one-hot vectors.
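For context, `variadic_sample` draws `k` samples with replacement from each variadic set: a uniform offset within each set, shifted by that set's start in the flat tensor. A self-contained trace (the body is inlined so the snippet runs on its own; the `input` values are made up for illustration):

```python
import torch

input = torch.tensor([10, 11, 12, 20, 21])  # sets {10, 11, 12} and {20, 21}
size = torch.tensor([3, 2])
k = 2

rand = torch.rand(len(size), k, device=size.device)
index = (rand * size.unsqueeze(-1)).long()             # offset in [0, size[i])
index = index + (size.cumsum(0) - size).unsqueeze(-1)  # shift by each set's start
print(input[index])  # shape (2, 2): row i holds k draws from set i
```

Since `torch.rand` is in `[0, 1)`, the floored offsets stay strictly below each `size[i]`, so an index never crosses into the next set.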