File tree: 1 file changed, +5 -3 lines changed
aiu_fms_testing_utils/utils: 1 file changed, +5 -3 lines changed
Original file line number | Diff line number | Diff line change
@@ -165,6 +165,8 @@ def generate(
165165 if "fp8" in kwargs ["attn_name" ]:
166166 from fms_mo .aiu_addons .fp8 .fp8_utils import ScaledTensor
167167
168+ already_scaled = prefill_chunk_size > 0
169+
168170 kwargs ["past_key_value_states" ] = [
169171 (
170172 ScaledTensor (
@@ -176,7 +178,7 @@ def generate(
176178 dtype = torch .float8_e4m3fn ,
177179 ),
178180 torch .tensor ([1.0 ] * input_ids .shape [0 ], dtype = torch .float32 ),
179- False ,
181+ already_scaled ,
180182 ),
181183 ScaledTensor (
182184 torch .zeros (
@@ -187,7 +189,7 @@ def generate(
187189 dtype = torch .float8_e4m3fn ,
188190 ),
189191 torch .tensor ([1.0 ] * input_ids .shape [0 ], dtype = torch .float32 ),
190- False ,
192+ already_scaled ,
191193 ),
192194 )
193195 for _ in range (model .config .nlayers )
@@ -421,7 +423,7 @@ def generate(
421423 current_kv_scales [layer_idx ][0 ][seq_i ] = t1 ._scale
422424 current_kv_scales [layer_idx ][1 ][seq_i ] = t2 ._scale
423425
424- if seq_i != input_ids .size (0 ) - 1 :
426+ if seq_i != input_ids .size (0 ) - 1 and prefill_chunk_size == 0 :
425427 for layer_cache in current_kv_cache :
426428 layer_cache [0 ]._scaled = False
427429 layer_cache [1 ]._scaled = False
You can’t perform that action at this time.
0 commit comments