@@ -469,7 +469,7 @@ def _apply_learned_naflex_pos_embed_grid_sample(
469469 Based on proposal by https://github.com/stas-sl
470470 """
471471 device = x .device
472- B , C = x .shape [ 0 : 2 ]
472+ B , N , C = x .shape
473473
474474 def _make_coords (h , w ):
475475 _y , _x = torch .meshgrid (
@@ -480,10 +480,12 @@ def _make_coords(h, w):
480480 coord = torch .stack ([_y .flatten (), _x .flatten ()], dim = 1 )
481481 return coord
482482
483- coords = pad_sequence (
484- [_make_coords (h , w ) for h , w in naflex_grid_sizes ],
485- batch_first = True ,
486- )
483+ coords = torch .zeros (B , N , 2 , dtype = torch .long , device = device )
484+ for i , (h , w ) in enumerate (naflex_grid_sizes ):
485+ coords_i = _make_coords (h , w ) # (h*w, 2)
486+ coords [i , :coords_i .shape [0 ]] = coords_i # pad with zeros past h*w
487+ # FIXME should we be masking?
488+
487489 shapes = coords .amax (1 ) + 1
488490 theta = torch .zeros (B , 2 , 3 , dtype = torch .float32 , device = device )
489491 if self .pos_embed_ar_preserving :
@@ -506,7 +508,7 @@ def _make_coords(h, w):
506508 align_corners = False ,
507509 padding_mode = 'border' ,
508510 ).to (dtype = x .dtype )
509- bi = torch .arange (B , device = device ).unsqueeze (1 ). expand ( - 1 , coords . shape [ 1 ])
511+ bi = torch .arange (B , device = device ).unsqueeze (1 )
510512 # NOTE leave as '+=', do not change to .add_(...)
511513 x += pos_embed [bi , :, coords [..., 0 ], coords [..., 1 ]]
512514
0 commit comments