
Commit b6e36d4

Add cache tests back
1 parent eafb818 commit b6e36d4

File tree: 1 file changed (+137, -0)


tests/models/test_decoders.py

Lines changed: 137 additions & 0 deletions
@@ -24,6 +24,7 @@
 )
 import json
 from aiu_fms_testing_utils.utils.aiu_setup import dprint, aiu_dist_setup
+import shutil
 import os

 try:
@@ -786,6 +787,7 @@ def _run_cpu_aiu_validation_test(
     cpu_model,
     aiu_model,
     micro_model_path,
+    verify_cache_state=None,
 ):
     # Get the tokenizer and AIU / CPU models to compare
     tokenizer = tokenizers.get_tokenizer(model_path)
@@ -811,6 +813,12 @@ def _run_cpu_aiu_validation_test(
         aiu_model,
     )

+    # Used only for cache tests; this is a nonparametric closure that
+    # should assert the cache for torch sendnn is in the correct state
+    # for this test
+    if verify_cache_state is not None:
+        verify_cache_state()
+
     # if level 0 fails validation, validate level 1
     if FORCE_VALIDATION_LEVEL_1 or failed_validation_level_0:
         if failed_validation_level_0:
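
For context, a minimal standalone sketch of the optional verification-hook pattern added above: the caller passes a zero-argument closure that asserts on shared state after the generation work completes. The names below ("run_generation", "check_state") are illustrative only and are not part of the test suite.

from typing import Callable, Optional


def run_generation(check_state: Optional[Callable[[], None]] = None) -> str:
    output = "generated tokens"  # stand-in for the AIU / CPU generation work
    if check_state is not None:
        check_state()  # runs only when a test wires in a closure
    return output


def test_hook_is_invoked():
    calls = []
    run_generation(check_state=lambda: calls.append(True))
    assert calls == [True]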
@@ -832,6 +840,87 @@ def _run_cpu_aiu_validation_test(
        )


+def _reset_cache_settings(purge_cache_dir):
+    os.environ["TORCH_SENDNN_CACHE_ENABLE"] = "1"
+    os.environ["COMPILATION_MODE"] = "offline_decoder"
+    cache_dir = os.environ["TORCH_SENDNN_CACHE_DIR"]
+
+    # Ensure we start in a clean state
+    if purge_cache_dir and os.path.isdir(cache_dir):
+        shutil.rmtree(cache_dir)
+        os.mkdir(cache_dir)
+
+    from torch_sendnn.backends import cache
+
+    # Explicitly clear cache paths from the global torch sendnn graph;
+    # TODO: it would be better to add a helper to do this explicitly
+    # in torch sendnn
+    cache.cache = {}
+
+
+@pytest.fixture
+def use_cached_model():
+    """Configures the torch sendnn cache and runs the AIU model prior to test execution;
+    this is computationally expensive and should only be used in situations like testing
+    cache hit correctness.
+    """
+    torch.manual_seed(42)
+    torch.set_grad_enabled(False)
+    _reset_cache_settings(purge_cache_dir=True)
+
+    model_path, batch_size, seq_length, max_new_tokens = _get_cache_test_params()
+    micro_model_path = MICRO_MODEL_MAPPING.get(model_path, None)
+
+    def verify_cache_miss():
+        cache_dir = os.environ.get("TORCH_SENDNN_CACHE_DIR")
+        updated_cache_len = (
+            len(os.listdir(cache_dir)) if os.path.isdir(cache_dir) else 0
+        )
+        assert updated_cache_len == max_new_tokens, (
+            "cache directory not populated on cache miss"
+        )
+
+    dprint(
+        f"Setting up cache [i.e., cache miss check] for model={model_path}, batch_size={batch_size}, seq_length={seq_length}, max_new_tokens={max_new_tokens}, micro_model={USE_MICRO_MODELS}"
+    )
+
+    # we don't currently support inferring gptq from get_model, so we must use an adapter with hf_configured
+    gptq_kwargs_aiu, gptq_kwargs_cpu = __maybe_get_gptq_kwargs(model_path)
+
+    model = _get_aiu_model(
+        model_path,
+        gptq_kwargs_aiu,
+        persistent_model_inst=None,
+    )
+
+    validation_model = _get_cpu_model(
+        model_path,
+        gptq_kwargs_cpu,
+        micro_model_state_dict=model.state_dict() if USE_MICRO_MODELS else None,
+    )
+
+    _run_cpu_aiu_validation_test(
+        model_path,
+        batch_size,
+        seq_length,
+        max_new_tokens,
+        validation_model,
+        model,
+        micro_model_path,
+        verify_cache_state=verify_cache_miss,
+    )
+
+
+def _get_cache_test_params():
+    # NOTE - currently we always use granite 3.3 for the cache test;
+    # TODO: make this configurable as tests are refactored
+    model_path = GRANITE_3p3_8B_INSTRUCT
+    batch_size = COMMON_BATCH_SIZES[0]
+    seq_length = COMMON_SEQ_LENGTHS[0]
+    max_new_tokens = COMMON_MAX_NEW_TOKENS[0]
+    return [model_path, batch_size, seq_length, max_new_tokens]
+
+
 @pytest.mark.parametrize(
     "model_path,batch_size,seq_length,max_new_tokens", common_shapes
 )
@@ -870,3 +959,51 @@ def test_common_shapes(
        model,
        micro_model_path,
    )
+
+
+def test_cache(use_cached_model):
+    torch.manual_seed(42)
+    torch.set_grad_enabled(False)
+    _reset_cache_settings(purge_cache_dir=False)
+
+    model_path, batch_size, seq_length, max_new_tokens = _get_cache_test_params()
+    micro_model_path = MICRO_MODEL_MAPPING.get(model_path, None)
+
+    def verify_cache_hit():
+        cache_dir = os.environ.get("TORCH_SENDNN_CACHE_DIR")
+        updated_cache_len = (
+            len(os.listdir(cache_dir)) if os.path.isdir(cache_dir) else 0
+        )
+        assert updated_cache_len == max_new_tokens, (
+            "cache miss occurred when hit was expected"
+        )
+
+    dprint(
+        f"testing: model={model_path}, batch_size={batch_size}, seq_length={seq_length}, max_new_tokens={max_new_tokens}, micro_model={USE_MICRO_MODELS}, for cache hit"
+    )
+
+    # we don't currently support inferring gptq from get_model, so we must use an adapter with hf_configured
+    gptq_kwargs_aiu, gptq_kwargs_cpu = __maybe_get_gptq_kwargs(model_path)
+
+    model = _get_aiu_model(
+        model_path,
+        gptq_kwargs_aiu,
+        persistent_model_inst=None,
+    )
+
+    validation_model = _get_cpu_model(
+        model_path,
+        gptq_kwargs_cpu,
+        micro_model_state_dict=model.state_dict() if USE_MICRO_MODELS else None,
+    )
+
+    _run_cpu_aiu_validation_test(
+        model_path,
+        batch_size,
+        seq_length,
+        max_new_tokens,
+        validation_model,
+        model,
+        micro_model_path,
+        verify_cache_state=verify_cache_hit,
+    )
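
A minimal, self-contained sketch of the cache miss/hit pattern these tests exercise, using a stand-in compile step instead of torch_sendnn and the AIU model helpers; "fake_compile", the one-artifact-per-token cache layout, and MAX_NEW_TOKENS are assumptions for illustration only.

import os

import pytest

MAX_NEW_TOKENS = 4


def fake_compile(cache_dir, step):
    # Stand-in for a compile call that writes one cache artifact per decode step
    # and skips the write when the artifact already exists (a cache hit).
    artifact = os.path.join(cache_dir, f"graph_{step}.bin")
    if not os.path.exists(artifact):
        with open(artifact, "wb") as f:
            f.write(b"compiled-graph")


@pytest.fixture
def warm_cache(tmp_path):
    # Populate the cache once before the test runs (the cache miss path),
    # mirroring what the use_cached_model fixture does with the real model.
    cache_dir = str(tmp_path / "sendnn_cache")
    os.makedirs(cache_dir, exist_ok=True)
    for step in range(MAX_NEW_TOKENS):
        fake_compile(cache_dir, step)
    assert len(os.listdir(cache_dir)) == MAX_NEW_TOKENS, (
        "cache directory not populated on cache miss"
    )
    return cache_dir


def test_cache_hit(warm_cache):
    # Re-running the same shapes should reuse the artifacts, not add new ones.
    for step in range(MAX_NEW_TOKENS):
        fake_compile(warm_cache, step)
    assert len(os.listdir(warm_cache)) == MAX_NEW_TOKENS, (
        "cache miss occurred when hit was expected"
    )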