 from pydantic import BaseModel, Field

 from ads.aqua.common.errors import AquaRecommendationError
-from ads.aqua.shaperecommend.constants import NEXT_QUANT, QUANT_MAPPING
+from ads.aqua.shaperecommend.constants import (
+    BITS_AND_BYTES_4BIT,
+    BITS_AND_BYTES_8BIT,
+    DEFAULT_WEIGHT_SIZE,
+    NEXT_QUANT,
+    QUANT_MAPPING,
+)


 class LLMConfig(BaseModel):
@@ -35,10 +41,11 @@ class LLMConfig(BaseModel):
         description="Dimension of each attention head. Typically hidden_size // num_attention_heads.",
     )
     max_seq_len: Optional[int] = Field(
-        8192, description="Maximum input sequence length (context window)."
+        4096, description="Maximum input sequence length (context window)."
     )
     weight_dtype: Optional[str] = Field(
-        "float32", description="Parameter data type: 'float32', 'float16', etc."
+        DEFAULT_WEIGHT_SIZE,
+        description="Parameter data type: 'float32', 'float16', etc.",
     )
     quantization: Optional[str] = Field(
         None,
@@ -49,6 +56,11 @@ class LLMConfig(BaseModel):
         description="Quantization method (e.g., '8bit', '4bit', 'gptq', 'awq') or None if unquantized.",
     )

+    in_flight_quantization: Optional[str] = Field(
+        None,
+        description="If set, the model footprint is recalculated assuming 4-bit in-flight quantization.",
+    )
+
     num_key_value_heads: Optional[int] = Field(
         None,
         description="Number of key/value heads (for GQA architectures: Llama, Mistral, Falcon, Qwen, etc.). Used to determine KV cache size",
@@ -82,9 +94,13 @@ def bytes_per_parameter(self) -> float:
             bits = int(m[1])
             return bits / 8  # bytes per parameter

+        # consider in-flight quantization
+        if self.in_flight_quantization in QUANT_MAPPING:
+            return QUANT_MAPPING[self.in_flight_quantization]
+
         # Fallback to dtype mapping
-        dtype = (self.weight_dtype or "float32").lower()
-        return QUANT_MAPPING.get(dtype, QUANT_MAPPING["float32"])
+        dtype = (self.weight_dtype or DEFAULT_WEIGHT_SIZE).lower()
+        return QUANT_MAPPING.get(dtype, QUANT_MAPPING[DEFAULT_WEIGHT_SIZE])

     @classmethod
     def detect_quantization_type(cls, raw: dict) -> Optional[str]:
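
A standalone sketch of the resolution order this hunk establishes: explicit quantization first, then the new in-flight quantization branch, then the stored weight dtype. The QUANT_MAPPING values and DEFAULT_WEIGHT_SIZE below are assumptions for illustration, not the actual constants from ads.aqua.shaperecommend.constants.

# Illustrative only; constant values are assumed, not taken from the PR.
QUANT_MAPPING = {"float32": 4.0, "float16": 2.0, "bfloat16": 2.0, "8bit": 1.0, "4bit": 0.5}
DEFAULT_WEIGHT_SIZE = "float32"

def bytes_per_parameter(quantization=None, in_flight_quantization=None, weight_dtype=None):
    # 1. a pre-quantized model wins, e.g. "4bit" -> 0.5 bytes per parameter
    if quantization and quantization.lower() in QUANT_MAPPING:
        return QUANT_MAPPING[quantization.lower()]
    # 2. otherwise, in-flight quantization (the branch added in this hunk)
    if in_flight_quantization in QUANT_MAPPING:
        return QUANT_MAPPING[in_flight_quantization]
    # 3. finally, fall back to the stored weight dtype
    dtype = (weight_dtype or DEFAULT_WEIGHT_SIZE).lower()
    return QUANT_MAPPING.get(dtype, QUANT_MAPPING[DEFAULT_WEIGHT_SIZE])

print(bytes_per_parameter(weight_dtype="bfloat16"))                                 # 2.0
print(bytes_per_parameter(weight_dtype="bfloat16", in_flight_quantization="4bit"))  # 0.5
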
@@ -114,9 +130,9 @@ def detect_quantization_bits(cls, raw: dict) -> Optional[str]:
         Detects quantization bit-width as a string (e.g., '4bit', '8bit') from Hugging Face config dict.
         """
         if raw.get("load_in_8bit"):
-            return "8bit"
+            return BITS_AND_BYTES_8BIT
         if raw.get("load_in_4bit"):
-            return "4bit"
+            return BITS_AND_BYTES_4BIT
         if "quantization_config" in raw:
             qcfg = raw["quantization_config"]
             bits = qcfg.get("bits") or qcfg.get("wbits")
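
A hypothetical usage sketch of this hunk. The config fragments below are made-up Hugging Face style dicts, the import path is assumed, and BITS_AND_BYTES_8BIT / BITS_AND_BYTES_4BIT are assumed to be the strings "8bit" and "4bit".

# Module path and constant values are assumptions, not confirmed by the diff.
from ads.aqua.shaperecommend.llm_config import LLMConfig

raw_bnb8 = {"load_in_8bit": True}
raw_bnb4 = {"load_in_4bit": True}
raw_gptq = {"quantization_config": {"quant_method": "gptq", "bits": 4}}

LLMConfig.detect_quantization_bits(raw_bnb8)  # -> BITS_AND_BYTES_8BIT (assumed "8bit")
LLMConfig.detect_quantization_bits(raw_bnb4)  # -> BITS_AND_BYTES_4BIT (assumed "4bit")
LLMConfig.detect_quantization_bits(raw_gptq)  # -> presumably "4bit", read from quantization_config["bits"]
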
@@ -132,7 +148,12 @@ def suggested_quantizations(self):
         If model is un-quantized, uses the weight size.
         If model is pre-quantized, uses the quantization level.
         """
-        key = (self.quantization or self.weight_dtype or "float32").lower()
+        key = (
+            self.quantization
+            or self.in_flight_quantization
+            or self.weight_dtype
+            or DEFAULT_WEIGHT_SIZE
+        ).lower()
         return NEXT_QUANT.get(key, [])

     def calculate_possible_seq_len(self, min_len=2048):
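
The only behavioral change in this hunk is the fall-through order of the lookup key. A minimal standalone sketch, with an assumed NEXT_QUANT that matches the updated optimal_config docstring further down.

# NEXT_QUANT contents and DEFAULT_WEIGHT_SIZE are assumed for illustration.
NEXT_QUANT = {"bfloat16": ["4bit"], "float16": ["4bit"], "4bit": []}
DEFAULT_WEIGHT_SIZE = "float32"

def suggested_quantizations(quantization=None, in_flight_quantization=None, weight_dtype=None):
    # pre-quantization wins, then in-flight quantization, then the raw weight dtype
    key = (quantization or in_flight_quantization or weight_dtype or DEFAULT_WEIGHT_SIZE).lower()
    return NEXT_QUANT.get(key, [])

print(suggested_quantizations(weight_dtype="bfloat16"))        # ['4bit']
print(suggested_quantizations(in_flight_quantization="4bit"))  # []
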
@@ -142,22 +163,21 @@ def calculate_possible_seq_len(self, min_len=2048):
         """
         vals = []
         curr = min_len
-        max_seq_len = 16384 if not self.max_seq_len else self.max_seq_len
-        while curr <= max_seq_len:
+        while curr <= self.max_seq_len:
             vals.append(curr)
             curr *= 2
-        if vals and vals[-1] != max_seq_len:
-            vals.append(max_seq_len)
+        if vals and vals[-1] != self.max_seq_len:
+            vals.append(self.max_seq_len)
         return vals

     def optimal_config(self):
         """
         Builds a list of optimal configuration parameters (sorted descending). Combination of:
-        - Quantization / weight sizes: bfloat16 weight size -> 8bit -> 4bit
+        - Quantization / weight sizes: bfloat16 weight size -> 4bit
         - max-model-len: power-of-two model lengths from max length (config.json of model) to 2048 tokens.

         Example:
-        [('bfloat16', max_model_len supported by model) ('bfloat16', 1/2 of max_model_len) ... ('int8', 2048), ('int4', 4096), ('int4', 2048)]
+        [('bfloat16', max_model_len supported by model) ('bfloat16', 1/2 of max_model_len) ... ('int4', 4096), ('int4', 2048)]

         """
         # Create a copy of the suggested_quantizations list
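
With the 16384 fallback removed, calculate_possible_seq_len now doubles from min_len up to whatever the model reports (or the new 4096 field default). A quick standalone check of the doubling logic, mirroring the code in this hunk; optimal_config() then pairs each suggested quantization with each of these lengths, sorted descending.

def calculate_possible_seq_len(max_seq_len, min_len=2048):
    # powers of two from min_len up to max_seq_len, appending max_seq_len if it is not a power of two
    vals, curr = [], min_len
    while curr <= max_seq_len:
        vals.append(curr)
        curr *= 2
    if vals and vals[-1] != max_seq_len:
        vals.append(max_seq_len)
    return vals

print(calculate_possible_seq_len(4096))   # [2048, 4096]
print(calculate_possible_seq_len(10000))  # [2048, 4096, 8192, 10000]
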
@@ -183,9 +203,11 @@ def validate_model_support(cls, raw: dict) -> ValueError:
         """
         excluded_models = {"t5", "gemma", "bart", "bert", "roberta", "albert"}
         if (
-            raw.get("is_encoder_decoder", False)  # exclude encoder-decoder models
-            or (raw.get("is_decoder") is False)  # exclude explicit encoder-only models (altho no text-generation task ones, just dbl check)
-            or raw.get("model_type", "").lower()  # exclude by known model types
+            raw.get("is_encoder_decoder", False)  # exclude encoder-decoder models
+            or (
+                raw.get("is_decoder") is False
+            )  # exclude explicit encoder-only models (altho no text-generation task ones, just dbl check)
+            or raw.get("model_type", "").lower()  # exclude by known model types
             in excluded_models
         ):
             raise AquaRecommendationError(
@@ -207,7 +229,7 @@ def from_raw_config(cls, raw: dict) -> "LLMConfig":
         )
         hidden_size = raw.get("hidden_size") or raw.get("n_embd") or raw.get("d_model")
         vocab_size = raw.get("vocab_size")
-        weight_dtype = str(raw.get("torch_dtype", "float32"))
+        weight_dtype = str(raw.get("torch_dtype", DEFAULT_WEIGHT_SIZE))
         quantization = cls.detect_quantization_bits(raw)
         quantization_type = cls.detect_quantization_type(raw)

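
A hedged end-to-end sketch of how the new defaults flow through from_raw_config. The raw dict below is hypothetical and only carries keys visible in this diff (plus num_hidden_layers, a common config.json key), so the real method may require additional fields; the module path is likewise assumed.

# Hypothetical input; keys, module path, and expected values are assumptions.
from ads.aqua.shaperecommend.llm_config import LLMConfig

raw = {
    "num_hidden_layers": 32,
    "hidden_size": 4096,
    "vocab_size": 32000,
    "torch_dtype": "bfloat16",
    "load_in_4bit": True,  # picked up by detect_quantization_bits
}
cfg = LLMConfig.from_raw_config(raw)
# Expected, assuming the dict is sufficient: cfg.weight_dtype == "bfloat16",
# cfg.quantization == BITS_AND_BYTES_4BIT, and cfg.max_seq_len falls back to the new 4096 default.
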