File tree Expand file tree Collapse file tree 2 files changed +11
-2
lines changed Expand file tree Collapse file tree 2 files changed +11
-2
lines changed Original file line number Diff line number Diff line change 11[project ]
22name = " vector-quantize-pytorch"
3- version = " 1.18.2 "
3+ version = " 1.18.3 "
44description = " Vector Quantization - Pytorch"
55authors = [
66 { name = " Phil Wang" , email = " lucidrains@gmail.com" }
Original file line number Diff line number Diff line change @@ -103,6 +103,7 @@ def __init__(
103103 commitment_loss_weight = 0. ,
104104 diversity_gamma = 1. ,
105105 straight_through_activation = nn .Identity (),
106+ scale_trick = False , # @cfifty Fifty et al. https://arxiv.org/abs/2410.06424
106107 num_codebooks = 1 ,
107108 keep_num_codebooks_dim = None ,
108109 codebook_scale = 1. , # for residual LFQ, codebook scaled down by 2x at each layer
@@ -160,6 +161,9 @@ def __init__(
160161
161162 self .activation = straight_through_activation
162163
164+ assert not (scale_trick and spherical )
165+ self .scale_trick = scale_trick
166+
163167 # whether to use BSQ (binary spherical quantization)
164168
165169 self .spherical = spherical
@@ -322,7 +326,12 @@ def forward(
322326
323327 if self .training :
324328 x = self .activation (x )
325- x = x + (quantized - x ).detach ()
329+
330+ if self .scale_trick :
331+ x = x * (quantized / x ).detach ()
332+ else :
333+ x = x + (quantized - x ).detach ()
334+
326335 else :
327336 x = quantized
328337
You can’t perform that action at this time.
0 commit comments