File tree Expand file tree Collapse file tree 1 file changed +3
-5
lines changed Expand file tree Collapse file tree 1 file changed +3
-5
lines changed Original file line number Diff line number Diff line change @@ -532,17 +532,15 @@ def forward(
532532 if self .affine_param :
533533 self .update_affine (flatten , self .embed , mask = mask )
534534
535- # affine params
535+ # get maybe learnable codes
536+ embed = self .embed if self .learnable_codebook else self .embed .detach ()
536537
538+ # affine params
537539 if self .affine_param :
538540 codebook_std = self .codebook_variance .clamp (min = 1e-5 ).sqrt ()
539541 batch_std = self .batch_variance .clamp (min = 1e-5 ).sqrt ()
540542 embed = (embed - self .codebook_mean ) * (batch_std / codebook_std ) + self .batch_mean
541543
542- # get maybe learnable codes
543-
544- embed = self .embed if self .learnable_codebook else self .embed .detach ()
545-
546544 # handle maybe implicit neural codebook
547545 # and calculate distance
548546
You can’t perform that action at this time.
0 commit comments