Commit 18bbf01

Add cache test
Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
1 parent 281ff22 commit 18bbf01

1 file changed: +148 -1 lines changed

tests/models/test_decoders.py

Lines changed: 148 additions & 1 deletion
@@ -7,6 +7,8 @@
 import torch
 from torch import distributed as dist
 from torch.fx.experimental import _config as fx_config
+from torch_sendnn.backends.sendnn_backend import _get_global_state
+from torch_sendnn.utils.graph_cache import SpyreGraphCache
 
 from aiu_fms_testing_utils.testing.validation import (
     extract_validation_information,
@@ -29,6 +31,7 @@
 from transformers import AutoTokenizer
 
 from aiu_fms_testing_utils.utils.aiu_setup import dprint, aiu_dist_setup
+import shutil
 import os
 
 try:
@@ -132,7 +135,7 @@
 if USE_MICRO_MODELS:
     VALIDATION_INFO_DIR = os.path.join(VALIDATION_INFO_DIR, "tiny_models")
 
-# pass custom model path list for eg: EXPORT FMS_TEST_SHAPES_COMMON_MODEL_PATHS="/tmp/models/granite-3-8b-base,/tmp/models/granite-7b-base"
+# pass custom model path list for eg: EXPORT FMS_TESTING_COMMON_MODEL_PATHS="/tmp/models/granite-3-8b-base,/tmp/models/granite-7b-base"
 if isinstance(COMMON_MODEL_PATHS, str):
     COMMON_MODEL_PATHS = COMMON_MODEL_PATHS.split(",")
 
@@ -593,6 +596,8 @@ def _get_device_validation_information(
         token_iter,
         ATTN_NAME,
     )
+    if cpu_validation_info is not None:
+        return cpu_validation_info
 
     if cpu_validation_info is not None:
         return cpu_validation_info
@@ -830,6 +835,7 @@ def _run_cpu_aiu_validation_test(
     aiu_model,
     micro_model_path,
     record_property,
+    verify_cache_state=None,
 ):
     # Get the tokenizer and AIU / CPU models to compare
     tokenizer = AutoTokenizer.from_pretrained(model_path)
@@ -866,6 +872,12 @@
         aiu_model,
     )
 
+    # Used only for cache tests; this is a nonparametric closure that
+    # should assert the cache for torch sendnn is in the correct state
+    # for this test
+    if verify_cache_state is not None:
+        verify_cache_state()
+
     # if level 0 fails validation, validate level 1
     if FORCE_VALIDATION_LEVEL_1 or failed_validation_level_0:
         if failed_validation_level_0:
@@ -888,6 +900,88 @@ def _run_cpu_aiu_validation_test(
     )
 
 
+def _get_cache_test_params():
+    # NOTE - currently we always use granite 3.3 for the cache test,
+    # TODO make this configurable as tests are refactored
+    model_path = GRANITE_3p3_8B_INSTRUCT
+    batch_size = COMMON_BATCH_SIZES[0]
+    seq_length = COMMON_SEQ_LENGTHS[0]
+    max_new_tokens = COMMON_MAX_NEW_TOKENS[0]
+    return [model_path, batch_size, seq_length, max_new_tokens]
+
+
+def _reset_cache_settings(purge_cache_dir):
+    os.environ["TORCH_SENDNN_CACHE_ENABLE"] = "1"
+    os.environ["COMPILATION_MODE"] = "offline_decoder"
+    cache_dir = os.environ["TORCH_SENDNN_CACHE_DIR"]
+
+    # Ensure we start in clean state
+    if purge_cache_dir and os.path.isdir(cache_dir):
+        shutil.rmtree(cache_dir)
+        os.mkdir(cache_dir)
+
+    _get_global_state().use_aiu_cache = True
+    _get_global_state().spyre_graph_cache = SpyreGraphCache()
+
+
+@pytest.fixture
+def use_cached_model(persistent_model, record_property):
+    """Configures the torchsendnn cache and runs the AIU model prior to test execution;
+    this is computationally expensive and should only be used in situations like testing
+    cache hit correctness;
+    """
+    torch.manual_seed(42)
+    torch.set_grad_enabled(False)
+    _reset_cache_settings(purge_cache_dir=True)
+
+    model_path, batch_size, seq_length, max_new_tokens = _get_cache_test_params()
+    micro_model_path = MICRO_MODEL_MAPPING.get(model_path, None)
+
+    def verify_cache_miss():
+        cache_dir = os.environ.get("TORCH_SENDNN_CACHE_DIR")
+        updated_cache_len = (
+            len(os.listdir(cache_dir)) if os.path.isdir(cache_dir) else 0
+        )
+        assert updated_cache_len == max_new_tokens, (
+            "cache directory not populated on cache miss"
+        )
+
+    dprint(
+        f"Setting up cache [i.e., cache miss check] for model={model_path}, batch_size={batch_size}, seq_length={seq_length}, max_new_tokens={max_new_tokens}, micro_model={USE_MICRO_MODELS}"
+    )
+
+    # we don't currently support inferring gptq from get_model, so we must use an adapter with hf_configured
+    gptq_kwargs_aiu, gptq_kwargs_cpu = __maybe_get_gptq_kwargs(model_path)
+    is_gptq = len(gptq_kwargs_aiu) != 0
+    is_fp8 = "fp8" in ATTN_NAME
+    model_kwargs = _get_common_model_kwargs(is_gptq, model_path)
+
+    # Get the AIU model w/ the persistent model fixture
+    model = persistent_model.get_or_create(
+        is_gptq, is_fp8, **gptq_kwargs_aiu, **model_kwargs
+    )
+
+    validation_model = _get_cpu_model(
+        is_gptq,
+        is_fp8,
+        micro_model_state_dict=model.state_dict() if USE_MICRO_MODELS else None,
+        **gptq_kwargs_cpu,
+        **model_kwargs,
+    )
+
+    _run_cpu_aiu_validation_test(
+        model_path,
+        batch_size,
+        seq_length,
+        max_new_tokens,
+        validation_model,
+        model,
+        micro_model_path,
+        record_property,
+        verify_cache_state=verify_cache_miss,
+    )
+
+
 @pytest.mark.parametrize(
     "model_path,batch_size,seq_length,max_new_tokens", COMMON_SHAPES
 )
@@ -937,3 +1031,56 @@ def test_common_shapes(
         micro_model_path,
         record_property,
     )
+
+
+def test_cache(use_cached_model, persistent_model, record_property):
+    torch.manual_seed(42)
+    torch.set_grad_enabled(False)
+    _reset_cache_settings(purge_cache_dir=False)
+
+    model_path, batch_size, seq_length, max_new_tokens = _get_cache_test_params()
+    micro_model_path = MICRO_MODEL_MAPPING.get(model_path, None)
+
+    def verify_cache_hit():
+        cache_dir = os.environ.get("TORCH_SENDNN_CACHE_DIR")
+        updated_cache_len = (
+            len(os.listdir(cache_dir)) if os.path.isdir(cache_dir) else 0
+        )
+        assert updated_cache_len == max_new_tokens, (
+            "cache miss occurred when hit was expected"
+        )
+
+    dprint(
+        f"testing: model={model_path}, batch_size={batch_size}, seq_length={seq_length}, max_new_tokens={max_new_tokens}, micro_model={USE_MICRO_MODELS}, for cache hit"
+    )
+
+    # we don't currently support inferring gptq from get_model, so we must use an adapter with hf_configured
+    gptq_kwargs_aiu, gptq_kwargs_cpu = __maybe_get_gptq_kwargs(model_path)
+    is_gptq = len(gptq_kwargs_aiu) != 0
+    is_fp8 = "fp8" in ATTN_NAME
+    model_kwargs = _get_common_model_kwargs(is_gptq, model_path)
+
+    # Get the AIU model w/ the persistent model fixture
+    model = persistent_model.get_or_create(
+        is_gptq, is_fp8, **gptq_kwargs_aiu, **model_kwargs
+    )
+
+    validation_model = _get_cpu_model(
+        is_gptq,
+        is_fp8,
+        micro_model_state_dict=model.state_dict() if USE_MICRO_MODELS else None,
+        **gptq_kwargs_cpu,
+        **model_kwargs,
+    )
+
+    _run_cpu_aiu_validation_test(
+        model_path,
+        batch_size,
+        seq_length,
+        max_new_tokens,
+        validation_model,
+        model,
+        micro_model_path,
+        record_property,
+        verify_cache_state=verify_cache_hit,
+    )
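
Note (reviewer sketch, not part of the commit): verify_cache_miss and verify_cache_hit both reduce to the same check, namely that the torch_sendnn cache directory holds exactly max_new_tokens entries after a run. A minimal standalone version of that check is sketched below; the fallback path /tmp/sendnn_cache and the value max_new_tokens = 8 are illustrative assumptions, not values taken from the commit.

import os

# Illustrative values; in the test these come from TORCH_SENDNN_CACHE_DIR
# in the environment and from _get_cache_test_params().
cache_dir = os.environ.get("TORCH_SENDNN_CACHE_DIR", "/tmp/sendnn_cache")
max_new_tokens = 8

# The test expects one cache entry per generated token.
cached_entries = len(os.listdir(cache_dir)) if os.path.isdir(cache_dir) else 0
assert cached_entries == max_new_tokens, (
    f"expected {max_new_tokens} cache entries, found {cached_entries}"
)

To exercise only the new test, running pytest tests/models/test_decoders.py::test_cache with TORCH_SENDNN_CACHE_DIR set to a writable scratch directory should suffice, assuming the usual AIU/Spyre test environment is available.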
