2 files changed: +10 −2

File 1 of 2:

@@ -1,4 +1,5 @@
 import torch
+from absl import flags
 from .layers import (
     create_quantized_from_nn_linear,
     create_quantized_from_nn_embedding,
@@ -7,6 +8,13 @@
 )


+_QUANTIZE_EMBEDDING = flags.DEFINE_bool(
+    "internal_quantize_embedding_layer",
+    True,
+    "Whether to quantize embedding layer or not. Defaults to true",
+)
+
+
 def quantize_model(float_model, config):
   """Apply quantization to linear layers."""

@@ -17,7 +25,7 @@ def quantize_nn_mod(float_model):
        new_mod = mod.get_quantized_version()
      elif isinstance(mod, torch.nn.Linear):
        new_mod = create_quantized_from_nn_linear(mod, config)
-      elif isinstance(mod, torch.nn.Embedding):
+      elif isinstance(mod, torch.nn.Embedding) and _QUANTIZE_EMBEDDING.value:
        new_mod = create_quantized_from_nn_embedding(mod, config)

      if new_mod:
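
For context, a minimal usage sketch (not part of the PR): because the flag is read at call time via _QUANTIZE_EMBEDDING.value, it can be set on the command line (--internal_quantize_embedding_layer=false) or overridden programmatically before calling quantize_model. The import path quantize_model_module below is hypothetical; only the flag name comes from the diff.

# Hypothetical sketch: disable embedding quantization programmatically.
# "quantize_model_module" is a placeholder for the file patched above.
from absl import flags

import quantize_model_module

FLAGS = flags.FLAGS

def quantize_keep_float_embeddings(float_model, config):
  # Outside app.run(main), flags must be marked as parsed before their
  # values can be read or assigned.
  if not FLAGS.is_parsed():
    FLAGS.mark_as_parsed()
  # Same effect as --internal_quantize_embedding_layer=false:
  # nn.Embedding modules stay in float; nn.Linear is still quantized.
  FLAGS["internal_quantize_embedding_layer"].value = False
  return quantize_model_module.quantize_model(float_model, config)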
File 2 of 2:

@@ -437,7 +437,7 @@ def forward(
     hidden_states = self.norm(hidden_states)

     embedder_weight = self.embedder.weight
-    if self.env.quant_config.enable_weight_quantization:
+    if hasattr(self.embedder, "weight_scaler"):
       embedder_weight = embedder_weight * self.embedder.weight_scaler
     logits = torch.matmul(hidden_states, embedder_weight.t())
     return logits
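
Why the hasattr check: with the new flag, weight quantization can be enabled globally while the embedder itself stays in float, in which case it carries no weight_scaler and the old condition would have hit a missing attribute. A minimal sketch of the gated rescaling (the TiedLogits class and the scaler shape are illustrative, not from the repo):

import torch

class TiedLogits(torch.nn.Module):
  # Illustrative stand-in for the model's tied embedder / output head.
  def __init__(self, embedder: torch.nn.Module):
    super().__init__()
    self.embedder = embedder

  def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
    embedder_weight = self.embedder.weight
    # Rescale only when the embedder was actually quantized and therefore
    # carries a scaler; a float embedder has no such attribute.
    if hasattr(self.embedder, "weight_scaler"):
      embedder_weight = embedder_weight * self.embedder.weight_scaler
    return torch.matmul(hidden_states, embedder_weight.t())

# Float embedder: no weight_scaler, so the multiply is skipped.
float_head = TiedLogits(torch.nn.Embedding(100, 16))
logits = float_head(torch.randn(2, 16))

# Quantized-style embedder: a per-row scaler buffer triggers the rescale.
quant_emb = torch.nn.Embedding(100, 16)
quant_emb.register_buffer("weight_scaler", torch.ones(100, 1))
logits_q = TiedLogits(quant_emb)(torch.randn(2, 16))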