
Commit 6bc631c

Merge pull request #10 from foundation-model-stack/contiguous-cache

Force kv cache to be contiguous to reduce number of graph traces

2 parents: b912264 + e4b2bb8
File tree

5 files changed (+5, -9 lines)
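
Why forcing a contiguous KV cache reduces the number of graph traces (an illustrative sketch; the tensor names and shapes below are assumptions, not the fms-internal cache code): when the cache grows by concatenation, every decode step produces a tensor with a new sequence length, and a shape-specializing backend (torch.compile, or a tracing backend such as the one used for the AIU) may re-trace the graph for each new shape. Writing the step's keys/values into a preallocated, contiguous buffer keeps the cache shape fixed, so one trace can be reused across steps.

import torch

def grow_cache(cache: torch.Tensor, new_kv: torch.Tensor) -> torch.Tensor:
    # Concatenation allocates a new tensor with a new sequence length each
    # step, so a shape-specialized backend may trace a fresh graph every time.
    return torch.cat([cache, new_kv], dim=2)

def write_contiguous(cache: torch.Tensor, new_kv: torch.Tensor, pos: int) -> torch.Tensor:
    # In-place write into a preallocated buffer: the cache keeps the same
    # shape and storage, so a single traced graph can be reused.
    cache[:, :, pos : pos + new_kv.shape[2], :] = new_kv
    return cache

batch, heads, max_len, head_dim = 1, 8, 128, 64
cache = torch.zeros(batch, heads, max_len, head_dim)
step_kv = torch.randn(batch, heads, 1, head_dim)
cache = write_contiguous(cache, step_kv, pos=0)

The diffs below do not implement any of this themselves; they only thread contiguous_cache=True through every fms.utils.generation.generate call site in the repository.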

.gitignore

Lines changed: 0 additions & 1 deletion
@@ -7,4 +7,3 @@ aiu-fms-testing-utils.egg-info
 */**/*.pyc
 .vscode
 aiu-fms-testing-utils.egg-info
-

aiu_fms_testing_utils/testing/validation.py

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,6 @@
 from typing import List, Tuple, Callable, MutableMapping, Any, Optional
 
 import torch
-import torch.nn as nn
 from fms.utils.generation import generate
 from aiu_fms_testing_utils.utils import ids_for_prompt
 from aiu_fms_testing_utils.utils.aiu_setup import dprint
@@ -205,6 +204,7 @@ def extract_validation_information(model, input_ids, max_new_tokens, post_iterat
         post_iteration_hook=post_iteration_hook,
         eos_token_id=eos_token_id,
         timing=timing,
+        contiguous_cache=True,
         extra_kwargs=extra_generation_kwargs,
     )
 

aiu_fms_testing_utils/utils/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -15,7 +15,7 @@ def warmup_model(model: nn.Module, input_ids: torch.Tensor, max_new_tokens: int,
     dprint("AIU warmup")
     pt_compile_model_time = time.time()
     extra_kwargs = {**padding_kwargs, "only_last_token": True}
-    generate(model, input_ids, max_new_tokens=max_new_tokens, max_seq_len=model.config.max_expected_seq_len, use_cache=True, do_sample=False, extra_kwargs=extra_kwargs)
+    generate(model, input_ids, max_new_tokens=max_new_tokens, max_seq_len=model.config.max_expected_seq_len, use_cache=True, do_sample=False, contiguous_cache=True, extra_kwargs=extra_kwargs)
     pt_compile_model_time = time.time() - pt_compile_model_time
     dprint(f"PT compile complete, took {pt_compile_model_time:.3f}s")
 
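
Seen across all the call sites in this commit, the change is one keyword argument per generate() invocation. A minimal call-pattern sketch, assuming model, input_ids, max_new_tokens, and extra_kwargs have already been prepared as in the surrounding scripts (those names are placeholders copied from the diff above, not a self-contained example):

from fms.utils.generation import generate

# Only contiguous_cache=True is new in this commit; everything else mirrors
# the existing warmup_model call shown above.
generate(
    model,
    input_ids,
    max_new_tokens=max_new_tokens,
    max_seq_len=model.config.max_expected_seq_len,
    use_cache=True,
    do_sample=False,
    contiguous_cache=True,
    extra_kwargs=extra_kwargs,
)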

scripts/inference.py

Lines changed: 1 addition & 0 deletions
@@ -589,6 +589,7 @@ def infer(use_cache, do_sample, warmup):
         max_seq_len=max_seq_len,
         timing=args.timing,
         eos_token_id=eos_token_id,
+        contiguous_cache=True,
         extra_kwargs=extra_generation_kwargs,
     )
     if args.timing != "":

scripts/validation.py

Lines changed: 2 additions & 6 deletions
@@ -1,22 +1,18 @@
 import argparse
-import itertools
 import json
 import os
 import random
-import sys
 import time
 from pathlib import Path
-from typing import Any, Callable, MutableMapping, Optional, Tuple
-import sys
 import ast
 
 import numpy as np
 import torch
 import torch._inductor.config
 from fms.models import get_model, register_model
 from fms.models.llama import LLaMAConfig, _llama_factory_factory
-from fms.utils import fusion, generation, tokenizers
-from fms.utils.generation import generate, pad_input_ids
+from fms.utils import generation, tokenizers
+from fms.utils.generation import pad_input_ids
 from torch import distributed as dist
 from aiu_fms_testing_utils.utils import warmup_model
 from aiu_fms_testing_utils.testing.validation import LogitsExtractorHook, capture_level_1_metrics, extract_validation_information, StaticTokenInjectorHook, GoldenTokenHook, filter_failed_level_1_cases, validate_level_0, load_validation_information, print_failed_cases
