Commit 5e84730

Merge pull request #167 from foundation-model-stack/chunked_fp8
Force static kv cache for DPP+fp8+chunked prefill
2 parents 42161e1 + 93e0c09 commit 5e84730

File tree

1 file changed: +5 -3 lines
aiu_fms_testing_utils/utils/paged.py

Lines changed: 5 additions & 3 deletions
@@ -165,6 +165,8 @@ def generate(
     if "fp8" in kwargs["attn_name"]:
         from fms_mo.aiu_addons.fp8.fp8_utils import ScaledTensor

+        already_scaled = prefill_chunk_size > 0
+
         kwargs["past_key_value_states"] = [
             (
                 ScaledTensor(
@@ -176,7 +178,7 @@ def generate(
                         dtype=torch.float8_e4m3fn,
                     ),
                     torch.tensor([1.0] * input_ids.shape[0], dtype=torch.float32),
-                    False,
+                    already_scaled,
                 ),
                 ScaledTensor(
                     torch.zeros(
@@ -187,7 +189,7 @@ def generate(
                         dtype=torch.float8_e4m3fn,
                     ),
                     torch.tensor([1.0] * input_ids.shape[0], dtype=torch.float32),
-                    False,
+                    already_scaled,
                 ),
             )
             for _ in range(model.config.nlayers)
@@ -421,7 +423,7 @@ def generate(
                 current_kv_scales[layer_idx][0][seq_i] = t1._scale
                 current_kv_scales[layer_idx][1][seq_i] = t2._scale

-            if seq_i != input_ids.size(0) - 1:
+            if seq_i != input_ids.size(0) - 1 and prefill_chunk_size == 0:
                 for layer_cache in current_kv_cache:
                     layer_cache[0]._scaled = False
                     layer_cache[1]._scaled = False
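
The intent of the change, as a minimal standalone sketch: when chunked prefill is active (prefill_chunk_size > 0), the fp8 KV cache tensors are created with their scaled flag already set, so the scales stay static across prefill chunks instead of being re-derived, and the per-sequence reset to dynamic scaling only happens when chunked prefill is off. The MiniScaledTensor class and init_kv_cache helper below are hypothetical stand-ins for fms_mo's ScaledTensor and the cache setup inside generate() in paged.py; shapes and names are illustrative only.

```python
import torch


class MiniScaledTensor:
    # Hypothetical stand-in for fms_mo.aiu_addons.fp8.fp8_utils.ScaledTensor:
    # an fp8 buffer plus a per-sequence scale and a "scaled" flag.
    def __init__(self, data: torch.Tensor, scale: torch.Tensor, scaled: bool):
        self._data = data
        self._scale = scale
        self._scaled = scaled


def init_kv_cache(batch_size: int, nlayers: int, prefill_chunk_size: int):
    # Mirrors the first hunk of the diff: with chunked prefill the cache
    # starts out "already scaled", keeping the kv-cache scales static.
    already_scaled = prefill_chunk_size > 0
    return [
        (
            MiniScaledTensor(
                torch.zeros(batch_size, 8, dtype=torch.float8_e4m3fn),
                torch.tensor([1.0] * batch_size, dtype=torch.float32),
                already_scaled,
            ),
            MiniScaledTensor(
                torch.zeros(batch_size, 8, dtype=torch.float8_e4m3fn),
                torch.tensor([1.0] * batch_size, dtype=torch.float32),
                already_scaled,
            ),
        )
        for _ in range(nlayers)
    ]


# With chunked prefill the flag starts True; without it, False. The last hunk
# of the diff likewise only resets _scaled to False between sequences when
# prefill_chunk_size == 0.
cache = init_kv_cache(batch_size=2, nlayers=3, prefill_chunk_size=512)
print(cache[0][0]._scaled)  # True
cache = init_kv_cache(batch_size=2, nlayers=3, prefill_chunk_size=0)
print(cache[0][0]._scaled)  # False
```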
