@@ -27,6 +27,8 @@ vq = VectorQuantize(

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = vq(x) # (1, 1024, 256), (1, 1024), (1)
+print(quantized.shape, indices.shape, commit_loss.shape)
+# > torch.Size([1, 1024, 256]) torch.Size([1, 1024]) torch.Size([1])
```

## Residual VQ
@@ -46,16 +48,14 @@ residual_vq = ResidualVQ(
x = torch.randn(1, 1024, 256)

quantized, indices, commit_loss = residual_vq(x)
-
-# (1, 1024, 256), (1, 1024, 8), (1, 8)
-# (batch, seq, dim), (batch, seq, quantizer), (batch, quantizer)
+print(quantized.shape, indices.shape, commit_loss.shape)
+# > torch.Size([1, 1024, 256]) torch.Size([1, 1024, 8]) torch.Size([1, 8])

# if you need all the codes across the quantization layers, just pass return_all_codes = True

quantized, indices, commit_loss, all_codes = residual_vq(x, return_all_codes = True)
-
-# *_, (8, 1, 1024, 256)
-# all_codes - (quantizer, batch, seq, dim)
+print(all_codes.shape)
+# > torch.Size([8, 1, 1024, 256])
```

Furthermore, <a href="https://arxiv.org/abs/2203.01941">this paper</a> uses Residual-VQ to construct the RQ-VAE, for generating high resolution images with more compressed codes.
@@ -77,9 +77,8 @@ residual_vq = ResidualVQ(

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = residual_vq(x)
-
-# (1, 1024, 256), (8, 1, 1024), (8, 1)
-# (batch, seq, dim), (quantizer, batch, seq), (quantizer, batch)
+print(quantized.shape, indices.shape, commit_loss.shape)
+# > torch.Size([1, 1024, 256]) torch.Size([1, 1024, 8]) torch.Size([1, 8])
```

<a href="https://arxiv.org/abs/2305.02765">A recent paper</a> further proposes to do residual VQ on groups of the feature dimension, showing equivalent results to Encodec while using far fewer codebooks. You can use it by importing `GroupedResidualVQ`
@@ -98,9 +97,8 @@ residual_vq = GroupedResidualVQ(
x = torch.randn(1, 1024, 256)

quantized, indices, commit_loss = residual_vq(x)
-
-# (1, 1024, 256), (2, 1, 1024, 8), (2, 1, 8)
-# (batch, seq, dim), (groups, batch, seq, quantizer), (groups, batch, quantizer)
+print(quantized.shape, indices.shape, commit_loss.shape)
+# > torch.Size([1, 1024, 256]) torch.Size([2, 1, 1024, 8]) torch.Size([2, 1, 8])

```

@@ -122,6 +120,8 @@ residual_vq = ResidualVQ(

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = residual_vq(x)
+print(quantized.shape, indices.shape, commit_loss.shape)
+# > torch.Size([1, 1024, 256]) torch.Size([1, 1024, 4]) torch.Size([1, 4])
```

## Increasing codebook usage
@@ -144,6 +144,8 @@ vq = VectorQuantize(

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = vq(x)
+print(quantized.shape, indices.shape, commit_loss.shape)
+# > torch.Size([1, 1024, 256]) torch.Size([1, 1024]) torch.Size([1])
```

### Cosine similarity
@@ -162,6 +164,8 @@ vq = VectorQuantize(

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = vq(x)
+print(quantized.shape, indices.shape, commit_loss.shape)
+# > torch.Size([1, 1024, 256]) torch.Size([1, 1024]) torch.Size([1])
```

### Expiring stale codes
@@ -180,6 +184,8 @@ vq = VectorQuantize(

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = vq(x)
+print(quantized.shape, indices.shape, commit_loss.shape)
+# > torch.Size([1, 1024, 256]) torch.Size([1, 1024]) torch.Size([1])
```

### Orthogonal regularization loss
@@ -204,6 +210,8 @@ vq = VectorQuantize(
img_fmap = torch.randn(1, 256, 32, 32)
quantized, indices, loss = vq(img_fmap) # (1, 256, 32, 32), (1, 32, 32), (1,)
# loss now contains the orthogonal regularization loss with the weight as assigned
+print(quantized.shape, indices.shape, loss.shape)
+# > torch.Size([1, 256, 32, 32]) torch.Size([1, 32, 32]) torch.Size([1])
```

### Multi-headed VQ
@@ -226,10 +234,12 @@ vq = VectorQuantize(
)

img_fmap = torch.randn(1, 256, 32, 32)
-quantized, indices, loss = vq(img_fmap) # (1, 256, 32, 32), (1, 32, 32, 8), (1,)
+quantized, indices, loss = vq(img_fmap)
+print(quantized.shape, indices.shape, loss.shape)
+# > torch.Size([1, 256, 32, 32]) torch.Size([1, 32, 32, 8]) torch.Size([1])

-# indices shape - (batch, height, width, heads)
```
+
### Random Projection Quantizer

<a href="https://arxiv.org/abs/2202.01855">This paper</a> first proposed to use a random projection quantizer for masked speech modeling, where signals are projected with a randomly initialized matrix and then matched with a randomly initialized codebook. One therefore does not need to learn the quantizer. This technique was used by Google's <a href="https://ai.googleblog.com/2023/03/universal-speech-model-usm-state-of-art.html">Universal Speech Model</a> to achieve SOTA for speech-to-text modeling.
@@ -248,7 +258,9 @@ quantizer = RandomProjectionQuantizer(
)

x = torch.randn(1, 1024, 512)
-indices = quantizer(x) # (1, 1024, 16) - (batch, seq, num_codebooks)
+indices = quantizer(x)
+print(indices.shape)
+# > torch.Size([1, 1024, 16])
```

This repository should also automatically synchronize the codebooks in a multi-process setting. If somehow it isn't, please open an issue. You can override whether to synchronize codebooks or not by setting `sync_codebook = True | False`
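
As a minimal sketch of that override, assuming the constructor arguments from the earlier examples (`sync_codebook` is the documented flag; `codebook_size = 512` is only an illustrative value):

```python
import torch
from vector_quantize_pytorch import VectorQuantize

vq = VectorQuantize(
    dim = 256,
    codebook_size = 512,    # illustrative value, reuse your own
    sync_codebook = False   # force synchronization off (True forces it on)
)

x = torch.randn(1, 1024, 256)
quantized, indices, commit_loss = vq(x)
```
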
@@ -279,10 +291,11 @@ quantizer = FSQ(levels)
x = torch.randn(1, 1024, 4) # 4 since there are 4 levels
xhat, indices = quantizer(x)

-print(xhat.shape) # (1, 1024, 4) - (batch, seq, dim)
-print(indices.shape) # (1, 1024) - (batch, seq)
+print(xhat.shape)
+# > torch.Size([1, 1024, 4])
+print(indices.shape)
+# > torch.Size([1, 1024])

-assert xhat.shape == x.shape
assert torch.all(xhat == quantizer.indices_to_codes(indices))
```

@@ -305,14 +318,12 @@ x = torch.randn(1, 1024, 256)
residual_fsq.eval()

quantized, indices = residual_fsq(x)
-
-# (1, 1024, 256), (1, 1024, 8), (8)
-# (batch, seq, dim), (batch, seq, quantizers), (quantizers)
+print(quantized.shape, indices.shape)
+# > torch.Size([1, 1024, 256]) torch.Size([1, 1024, 8])

quantized_out = residual_fsq.get_output_from_indices(indices)
-
-# (8, 1, 1024, 8)
-# (residual layers, batch, seq, quantizers)
+print(quantized_out.shape)
+# > torch.Size([1, 1024, 256])

assert torch.all(quantized == quantized_out)
```
@@ -346,26 +357,34 @@ quantizer = LFQ(
image_feats = torch.randn(1, 16, 32, 32)

quantized, indices, entropy_aux_loss = quantizer(image_feats, inv_temperature = 100.) # you may want to experiment with temperature
+print(quantized.shape, indices.shape, entropy_aux_loss.shape)
+# > torch.Size([1, 16, 32, 32]) torch.Size([1, 32, 32]) torch.Size([])

-# (1, 16, 32, 32), (1, 32, 32), (1,)
-
-assert image_feats.shape == quantized.shape
assert (quantized == quantizer.indices_to_codes(indices)).all()
```

You can also pass in video features as `(batch, feat, time, height, width)` or sequences as `(batch, seq, feat)`

```python
+import torch
+from vector_quantize_pytorch import LFQ
+
+quantizer = LFQ(
+    codebook_size = 65536,
+    dim = 16,
+    entropy_loss_weight = 0.1,
+    diversity_gamma = 1.
+)

seq = torch.randn(1, 32, 16)
quantized, *_ = quantizer(seq)

-assert seq.shape == quantized.shape
+# assert seq.shape == quantized.shape

-video_feats = torch.randn(1, 16, 10, 32, 32)
-quantized, *_ = quantizer(video_feats)
+# video_feats = torch.randn(1, 16, 10, 32, 32)
+# quantized, *_ = quantizer(video_feats)

-assert video_feats.shape == quantized.shape
+# assert video_feats.shape == quantized.shape

```

@@ -384,8 +403,8 @@ quantizer = LFQ(
image_feats = torch.randn(1, 16, 32, 32)

quantized, indices, entropy_aux_loss = quantizer(image_feats)
-
-# (1, 16, 32, 32), (1, 32, 32, 4), (1,)
+print(quantized.shape, indices.shape, entropy_aux_loss.shape)
+# > torch.Size([1, 16, 32, 32]) torch.Size([1, 32, 32, 4]) torch.Size([])

assert image_feats.shape == quantized.shape
assert (quantized == quantizer.indices_to_codes(indices)).all()
@@ -408,14 +427,12 @@ x = torch.randn(1, 1024, 256)
residual_lfq.eval()

quantized, indices, commit_loss = residual_lfq(x)
-
-# (1, 1024, 256), (1, 1024, 8), (8)
-# (batch, seq, dim), (batch, seq, quantizers), (quantizers)
+print(quantized.shape, indices.shape, commit_loss.shape)
+# > torch.Size([1, 1024, 256]) torch.Size([1, 1024, 8]) torch.Size([8])

quantized_out = residual_lfq.get_output_from_indices(indices)
-
-# (8, 1, 1024, 8)
-# (residual layers, batch, seq, quantizers)
+print(quantized_out.shape)
+# > torch.Size([1, 1024, 256])

assert torch.all(quantized == quantized_out)
```
@@ -443,8 +460,8 @@ quantizer = LatentQuantize(
image_feats = torch.randn(1, 16, 32, 32)

quantized, indices, loss = quantizer(image_feats)
-
-# (1, 16, 32, 32), (1, 32, 32), (1,)
+print(quantized.shape, indices.shape, loss.shape)
+# > torch.Size([1, 16, 32, 32]) torch.Size([1, 32, 32]) torch.Size([])

assert image_feats.shape == quantized.shape
assert (quantized == quantizer.indices_to_codes(indices)).all()
@@ -454,15 +471,25 @@ You can also pass in video features as `(batch, feat, time, height, width)` or s

```python

+import torch
+from vector_quantize_pytorch import LatentQuantize
+
+quantizer = LatentQuantize(
+    levels = [5, 5, 8],
+    dim = 16,
+    commitment_loss_weight = 0.1,
+    quantization_loss_weight = 0.1,
+)
+
seq = torch.randn(1, 32, 16)
quantized, *_ = quantizer(seq)
-
-assert seq.shape == quantized.shape
+print(quantized.shape)
+# > torch.Size([1, 32, 16])

video_feats = torch.randn(1, 16, 10, 32, 32)
quantized, *_ = quantizer(video_feats)
-
-assert video_feats.shape == quantized.shape
+print(quantized.shape)
+# > torch.Size([1, 16, 10, 32, 32])

```

@@ -480,6 +507,8 @@ model = LatentQuantize(levels, dim, num_codebooks=num_codebooks)

input_tensor = torch.randn(2, 3, dim)
output_tensor, indices, loss = model(input_tensor)
+print(output_tensor.shape, indices.shape, loss.shape)
+# > torch.Size([2, 3, 9]) torch.Size([2, 3, 3]) torch.Size([])

assert output_tensor.shape == input_tensor.shape
assert indices.shape == (2, 3, num_codebooks)