@@ -106,7 +106,7 @@ def multi_slice_mask(starts, ends, length):
     slices = torch.cat([starts, ends])
     if slices.numel():
         assert slices.min() >= 0 and slices.max() <= length
-    mask = scatter_add(values, slices, dim_size=length + 1)[:-1]
+    mask = scatter_add(values, slices, dim=0, dim_size=length + 1)[:-1]
     mask = mask.cumsum(0).bool()
     return mask
 
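Note: `values` and `slices` are one-dimensional here, so `dim=0` coincides with the default; passing it explicitly just stops the call from depending on `torch_scatter`'s default dimension. A minimal sketch of the intended semantics, with hypothetical values:

    import torch
    # starts = [1], ends = [3], length = 5 marks positions 1 and 2:
    # multi_slice_mask(torch.tensor([1]), torch.tensor([3]), 5)
    # -> tensor([False, True, True, False, False])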
@@ -230,7 +230,7 @@ def variadic_max(input, size):
     index2sample = index2sample.expand_as(input)
 
     value, index = scatter_max(input, index2sample, dim=0)
-    index = index - size.cumsum(0) + size
+    index = index + (size - size.cumsum(0)).view([-1] + [1] * (index.ndim - 1))
     return value, index
 
 
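This fix makes the flat-to-local index conversion broadcast correctly: `scatter_max` returns argmax positions into the packed `input`, and subtracting each set's start offset (`size.cumsum(0) - size`) recovers within-set indices. The old expression broadcast `size` against the trailing dimensions of `index` whenever `input` had more than one dimension; the `.view([-1] + [1] * (index.ndim - 1))` aligns the offsets with the set dimension instead. A worked example under assumed toy values:

    import torch
    # input = tensor([1., 5., 2., 4., 3.]), size = tensor([2, 3])
    # scatter_max returns flat argmax indices [1, 3]
    # size - size.cumsum(0) = [0, -2], so within-set indices are [1, 1]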
@@ -314,7 +314,8 @@ def variadic_topk(input, size, k, largest=True):
     Parameters:
         input (Tensor): input of shape :math:`(B, ...)`
         size (LongTensor): size of sets of shape :math:`(N,)`
-        k (int): the k in "top-k"
+        k (int or LongTensor): the k in "top-k". Can be a fixed value for all sets,
+            or different values for different sets of shape :math:`(N,)`.
         largest (bool, optional): return largest or smallest elements
 
     Returns
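With the widened signature, a per-set `k` becomes possible; a hypothetical call for illustration:

    import torch
    # three packed sets of sizes [4, 2, 5], a different k per set
    # value, index = variadic_topk(input, torch.tensor([4, 2, 5]), k=torch.tensor([2, 2, 3]))
    # a tensor-valued k changes the output layout, as shown in the last hunk below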
@@ -326,13 +327,19 @@ def variadic_topk(input, size, k, largest=True):
     mask = ~torch.isinf(input)
     max = input[mask].max().item()
     min = input[mask].min().item()
-    safe_input = input.clamp(2 * min - max, 2 * max - min)
-    offset = (max - min) * 4
+    abs_max = input[mask].abs().max().item()
+    # special case: max = min
+    gap = max - min + abs_max * 1e-6
+    safe_input = input.clamp(min - gap, max + gap)
+    offset = gap * 4
     if largest:
         offset = -offset
     input_ext = safe_input + offset * index2graph
     index_ext = input_ext.argsort(dim=0, descending=largest)
-    num_actual = size.clamp(max=k)
+    if isinstance(k, torch.Tensor) and k.shape == size.shape:
+        num_actual = torch.min(size, k)
+    else:
+        num_actual = size.clamp(max=k)
     num_padding = k - num_actual
     starts = size.cumsum(0) - size
     ends = starts + num_actual
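The `gap` rewrite closes a degenerate case: when all finite entries are equal, `max - min` is 0, so the old offset `(max - min) * 4` vanished and the global `argsort` could interleave elements from different sets. With `gap = max - min + abs_max * 1e-6`, the clamped values of one set span at most `(max - min) + 2 * gap <= 3 * gap`, strictly less than the `4 * gap` stride separating consecutive sets (whenever the input is not all zeros).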
@@ -346,9 +353,14 @@ def variadic_topk(input, size, k, largest=True):
 
     index = index_ext[mask]  # (N * k, ...)
     value = input.gather(0, index)
-    value = value.view(-1, k, *input.shape[1:])
-    index = index.view(-1, k, *input.shape[1:])
-    index = index - (size.cumsum(0) - size).view([-1] + [1] * (index.ndim - 1))
+    if isinstance(k, torch.Tensor) and k.shape == size.shape:
+        value = value.view(-1, *input.shape[1:])
+        index = index.view(-1, *input.shape[1:])
+        index = index - (size.cumsum(0) - size).repeat_interleave(k).view([-1] + [1] * (index.ndim - 1))
+    else:
+        value = value.view(-1, k, *input.shape[1:])
+        index = index.view(-1, k, *input.shape[1:])
+        index = index - (size.cumsum(0) - size).view([-1] + [1] * (index.ndim - 1))
 
     return value, index
 
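Shape consequence: with an integer `k` the outputs keep their `(N, k, ...)` layout, while a tensor `k` yields variadic outputs packed along dim 0 with `k.sum()` rows; each set contributes exactly `k_i` rows, with undersized sets padded as in the fixed-k path. The per-set offsets must then be expanded with `repeat_interleave(k)` so every packed row gets its own set's offset. A hypothetical check of the packed layout:

    import torch
    # sizes [4, 2] with k = tensor([2, 1]): value.shape[0] == k.sum() == 3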
@@ -432,6 +444,39 @@ def variadic_sample(input, size, num_sample):
     return sample
 
 
+def variadic_meshgrid(input1, size1, input2, size2):
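+    """
+    Compute the Cartesian product for each pair of sets in two variadic tensors.
+
+    Parameters:
+        input1 (Tensor): input of shape :math:`(B_1, ...)`
+        size1 (LongTensor): size of sets of shape :math:`(N,)`
+        input2 (Tensor): input of shape :math:`(B_2, ...)`
+        size2 (LongTensor): size of sets of shape :math:`(N,)`
+    """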
+    grid_size = size1 * size2
+    local_index = variadic_arange(grid_size)
+    local_inner_size = size2.repeat_interleave(grid_size)
+    offset1 = (size1.cumsum(0) - size1).repeat_interleave(grid_size)
+    offset2 = (size2.cumsum(0) - size2).repeat_interleave(grid_size)
+    index1 = local_index // local_inner_size + offset1
+    index2 = local_index % local_inner_size + offset2
+    return input1[index1], input2[index2]
+
+
+def variadic_to_padded(input, size, value=0):
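+    """
+    Convert a variadic tensor to a padded tensor.
+
+    Parameters:
+        input (Tensor): input of shape :math:`(B, ...)`
+        size (LongTensor): size of sets of shape :math:`(N,)`
+        value (scalar): value used for padding
+
+    Returns:
+        (Tensor, BoolTensor): padded tensor of shape :math:`(N, max(size), ...)` and its mask
+    """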
+    num_sample = len(size)
+    max_size = size.max()
+    starts = torch.arange(num_sample, device=size.device) * max_size
+    ends = starts + size
+    mask = multi_slice_mask(starts, ends, num_sample * max_size)
+    mask = mask.view(num_sample, max_size)
+    shape = (num_sample, max_size) + input.shape[1:]
+    padded = torch.full(shape, value, dtype=input.dtype, device=size.device)
+    padded[mask] = input
+    return padded, mask
+
+
+def padded_to_variadic(padded, size):
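+    """
+    Convert a padded tensor back to a variadic tensor.
+
+    Parameters:
+        padded (Tensor): padded tensor of shape :math:`(N, max(size), ...)`
+        size (LongTensor): size of sets of shape :math:`(N,)`
+    """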
+    num_sample, max_size = padded.shape[:2]
+    starts = torch.arange(num_sample, device=size.device) * max_size
+    ends = starts + size
+    mask = multi_slice_mask(starts, ends, num_sample * max_size)
+    mask = mask.view(num_sample, max_size)
+    return padded[mask]
+
+
 def one_hot(index, size):
     """
     Expand indexes into one-hot vectors.