@@ -76,19 +76,6 @@ def lens_to_mask(lens, max_length):
7676 seq = torch .arange (max_length , device = lens .device )
7777 return seq < lens [:, None ]
7878
79- def efficient_rotation_trick_transform (u , q , e ):
80- """
81- 4.2 in https://arxiv.org/abs/2410.06424
82- """
83- e = rearrange (e , 'b d -> b 1 d' )
84- w = l2norm (u + q , dim = 1 ).detach ()
85-
86- return (
87- e -
88- 2 * (e @ rearrange (w , 'b d -> b d 1' ) @ rearrange (w , 'b d -> b 1 d' )) +
89- 2 * (e @ rearrange (u , 'b d -> b d 1' ).detach () @ rearrange (q , 'b d -> b 1 d' ).detach ())
90- )
91-
9279def uniform_init (* shape ):
9380 t = torch .empty (shape )
9481 nn .init .kaiming_uniform_ (t )
@@ -248,6 +235,39 @@ def kmeans(
248235
249236 return means , bins
250237
238+ # rotation trick related
239+
def efficient_rotation_trick_transform(u, q, e):
    """
    4.2 in https://arxiv.org/abs/2410.06424

    Computes  e - 2 * (e w) w^T + 2 * (e u) q^T  per batch row, with
    w = l2norm(u + q), without ever materialising a (d, d) matrix.

    u, q and e are each (b, d). w, u and q are detached, so only `e`
    carries gradient through the result.

    Returns a tensor of shape (b, 1, d); the caller drops the middle axis.
    """
    e_row = rearrange(e, 'b d -> b 1 d')

    # NOTE(review): l2norm is defined elsewhere in the file - presumably it
    # unit-normalises along `dim`; confirm
    w = l2norm(u + q, dim = 1).detach()

    w_col = rearrange(w, 'b d -> b d 1')
    w_row = rearrange(w, 'b d -> b 1 d')

    u_col = rearrange(u, 'b d -> b d 1').detach()
    q_row = rearrange(q, 'b d -> b 1 d').detach()

    # (b 1 d) @ (b d 1) -> (b 1 1), then @ (b 1 d) -> (b 1 d)
    along_w = e_row @ w_col @ w_row
    u_onto_q = e_row @ u_col @ q_row

    return e_row - 2 * along_w + 2 * u_onto_q
252+
def rotate_from_to(src, tgt):
    """
    rotation trick STE (https://arxiv.org/abs/2410.06424) to get gradients through VQ layer.

    `src` and `tgt` are each packed to rows of shape (n, d) with
    pack_one(..., '* d'). Every `tgt` row is passed through
    efficient_rotation_trick_transform, built from the unit directions of
    `tgt` and `src`, and then rescaled by the detached ratio |src| / |tgt|.

    The directions and the scale are detached, so gradient reaches `tgt`
    only through the `e` argument of the transform; no gradient reaches `src`.

    Returns the result passed through the inverse function that pack_one
    gave back for `tgt`.
    """
    tgt, inverse = pack_one(tgt, '* d')
    src, _ = pack_one(src, '* d')

    norm_tgt = tgt.norm(dim = -1, keepdim = True)
    norm_src = src.norm(dim = -1, keepdim = True)

    # the transform returns (n, 1, d); drop only that inserted axis.
    # a bare .squeeze() would also collapse n == 1 or d == 1, and with
    # d == 1 the rescale below would broadcast (n,) * (n, 1) into (n, n)
    rotated_tgt = efficient_rotation_trick_transform(
        safe_div(tgt, norm_tgt),
        safe_div(src, norm_src),
        tgt
    ).squeeze(1)

    # match the magnitude of src; detached so the scale acts as a constant
    rotated = rotated_tgt * safe_div(norm_src, norm_tgt).detach()

    return inverse(rotated)
270+
251271# distributed helpers
252272
253273@cache
@@ -1098,22 +1118,7 @@ def forward(
10981118 commit_quantize = maybe_detach (quantize )
10991119
11001120 if self .rotation_trick :
1101- # rotation trick STE (https://arxiv.org/abs/2410.06424) to get gradients through VQ layer.
1102- x , inverse = pack_one (x , '* d' )
1103- quantize , _ = pack_one (quantize , '* d' )
1104-
1105- norm_x = x .norm (dim = - 1 , keepdim = True )
1106- norm_quantize = quantize .norm (dim = - 1 , keepdim = True )
1107-
1108- rot_quantize = efficient_rotation_trick_transform (
1109- safe_div (x , norm_x ),
1110- safe_div (quantize , norm_quantize ),
1111- x
1112- ).squeeze ()
1113-
1114- quantize = rot_quantize * safe_div (norm_quantize , norm_x ).detach ()
1115-
1116- x , quantize = inverse (x ), inverse (quantize )
1121+ quantize = rotate_from_to (quantize , x )
11171122 else :
11181123 # standard STE to get gradients through VQ layer.
11191124 quantize = x + (quantize - x ).detach ()
0 commit comments