4 files changed (+15, -14 lines)
lines changed Original file line number Diff line number Diff line change 3434 "if set, then save the result to the given file name" ,
3535)
3636flags .DEFINE_bool (
37- "internal_use_local_tokenizer" ,
38- 0 ,
39- "Use local tokenizer if set to True"
37+ "internal_use_local_tokenizer" , 0 , "Use local tokenizer if set to True"
4038)
4139
40+
4241def shard_weights (env , weights , weight_shardings ):
4342 """Shard weights according to weight_shardings"""
4443 sharded = {}
File 2 of 4:

 flags.DEFINE_bool(
     "quantize_kv_cache", None, "defaults to the same value as quantize_weights"
 )
+flags.DEFINE_multi_string(
+    "quantize_exclude_layers",
+    None,
+    "List of layer names to exclude from quantization",
+)
 
 _VALID_QUANTIZATION_TYPE = {
     "int8_per_channel",
@@ -178,6 +183,7 @@ def create_quantization_config_from_flags():
   config.is_blockwise_weight = "blockwise" in quantize_type
 
   config.enable_activation_quantization = FLAGS.quantize_activation
+  config.exclude_layers = FLAGS.quantize_exclude_layers
   config.enable_kv_quantization = (
       FLAGS.quantize_kv_cache
       if FLAGS.quantize_kv_cache is not None
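Because quantize_exclude_layers is a DEFINE_multi_string flag, it is passed once per layer name and absl collects the repeats into a list of strings, defaulting to None when the flag is never given; that is why create_quantization_config_from_flags can assign it to the config unchanged. A minimal, runnable sketch of that behavior (the script invocation and layer names below are illustrative assumptions, not part of this change):

from absl import app, flags

flags.DEFINE_multi_string(
    "quantize_exclude_layers",
    None,
    "List of layer names to exclude from quantization",
)

def main(argv):
  del argv  # unused
  # Invoked as:
  #   python demo.py --quantize_exclude_layers=lm_head \
  #                  --quantize_exclude_layers=embed_tokens
  # this prints ['lm_head', 'embed_tokens']; with no occurrences it prints None.
  print(flags.FLAGS.quantize_exclude_layers)

if __name__ == "__main__":
  app.run(main)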
File 3 of 4:

 # limitations under the License.
 
 import dataclasses
-from typing import Tuple
+from typing import List, Tuple, Union
 
 import jax
 import jax.numpy as jnp
@@ -37,6 +37,7 @@ class QuantizationConfig:
 
   enable_activation_quantization: bool = False
   enable_kv_quantization: bool = False
+  exclude_layers: Union[None, List[str]] = None
 
 
 @dataclasses.dataclass
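Union[None, List[str]] is the same type as Optional[List[str]]; the default of None (rather than an empty list) mirrors the flag's default, and the truthiness check in quantize_model below treats the two identically. For callers that build the config in code instead of from flags, a small sketch (the dataclass is abridged to the fields in this hunk, and "lm_head" is a hypothetical layer name):

import dataclasses
from typing import List, Union

@dataclasses.dataclass
class QuantizationConfig:
  # Abridged: only the fields shown in the hunk above.
  enable_activation_quantization: bool = False
  enable_kv_quantization: bool = False
  exclude_layers: Union[None, List[str]] = None

config = QuantizationConfig(exclude_layers=["lm_head"])
assert config.exclude_layers == ["lm_head"]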
File 4 of 4:

 import torch
-from absl import flags
+from .environment import QuantizationConfig
 from .layers import (
     create_quantized_from_nn_linear,
     create_quantized_from_nn_embedding,
...
 )
 
 
-_QUANTIZE_EMBEDDING = flags.DEFINE_bool(
-    "internal_quantize_embedding_layer",
-    True,
-    "Whether to quantize embedding layer or not. Defaults to true",
-)
-
-
-def quantize_model(float_model, config):
+def quantize_model(float_model, config: QuantizationConfig):
   """Apply quantization to linear layers."""
 
   def quantize_nn_mod(float_model):
     for name, mod in float_model.named_modules():
       new_mod = None
+      if config.exclude_layers and name in config.exclude_layers:
+        continue
       if hasattr(mod, "get_quantized_version"):
         new_mod = mod.get_quantized_version()
       elif isinstance(mod, torch.nn.Linear):
         new_mod = create_quantized_from_nn_linear(mod, config)
-      elif isinstance(mod, torch.nn.Embedding) and _QUANTIZE_EMBEDDING.value:
+      elif isinstance(mod, torch.nn.Embedding):
         new_mod = create_quantized_from_nn_embedding(mod, config)
 
       if new_mod:
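Two behavioral points are worth noting: the exclusion check runs before every branch, so an excluded name is skipped even when the module offers get_quantized_version, and with the internal_quantize_embedding_layer flag deleted, nn.Embedding modules are always quantized unless named in the exclude list. The names must match what torch's named_modules() yields. A runnable sketch with a toy model (the module names are illustrative, not from this change):

import torch

class ToyModel(torch.nn.Module):
  def __init__(self):
    super().__init__()
    self.embed = torch.nn.Embedding(100, 16)
    self.proj = torch.nn.Linear(16, 16)

for name, mod in ToyModel().named_modules():
  # Prints '' (the root), then 'embed', then 'proj'; an exclude list
  # of ["embed"] would therefore skip only the embedding layer.
  print(repr(name), type(mod).__name__)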