From 18bbf015538f8211920ec43f4237914e1ff2144f Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Mon, 13 Oct 2025 08:25:47 +0000 Subject: [PATCH 1/8] Add cache test Signed-off-by: Alex-Brooks --- tests/models/test_decoders.py | 149 +++++++++++++++++++++++++++++++++- 1 file changed, 148 insertions(+), 1 deletion(-) diff --git a/tests/models/test_decoders.py b/tests/models/test_decoders.py index 4f95e61e..5aa64aa5 100644 --- a/tests/models/test_decoders.py +++ b/tests/models/test_decoders.py @@ -7,6 +7,8 @@ import torch from torch import distributed as dist from torch.fx.experimental import _config as fx_config +from torch_sendnn.backends.sendnn_backend import _get_global_state +from torch_sendnn.utils.graph_cache import SpyreGraphCache from aiu_fms_testing_utils.testing.validation import ( extract_validation_information, @@ -29,6 +31,7 @@ from transformers import AutoTokenizer from aiu_fms_testing_utils.utils.aiu_setup import dprint, aiu_dist_setup +import shutil import os try: @@ -132,7 +135,7 @@ if USE_MICRO_MODELS: VALIDATION_INFO_DIR = os.path.join(VALIDATION_INFO_DIR, "tiny_models") -# pass custom model path list for eg: EXPORT FMS_TEST_SHAPES_COMMON_MODEL_PATHS="/tmp/models/granite-3-8b-base,/tmp/models/granite-7b-base" +# pass custom model path list for eg: EXPORT FMS_TESTING_COMMON_MODEL_PATHS="/tmp/models/granite-3-8b-base,/tmp/models/granite-7b-base" if isinstance(COMMON_MODEL_PATHS, str): COMMON_MODEL_PATHS = COMMON_MODEL_PATHS.split(",") @@ -593,6 +596,8 @@ def _get_device_validation_information( token_iter, ATTN_NAME, ) + if cpu_validation_info is not None: + return cpu_validation_info if cpu_validation_info is not None: return cpu_validation_info @@ -830,6 +835,7 @@ def _run_cpu_aiu_validation_test( aiu_model, micro_model_path, record_property, + verify_cache_state=None, ): # Get the tokenizer and AIU / CPU models to compare tokenizer = AutoTokenizer.from_pretrained(model_path) @@ -866,6 +872,12 @@ def _run_cpu_aiu_validation_test( aiu_model, ) + # Used only for cache tests; this is a nonparametric closure that + # should assert the cache for torch sendnn is in the correct state + # for this test + if verify_cache_state is not None: + verify_cache_state() + # if level 0 fails validation, validate level 1 if FORCE_VALIDATION_LEVEL_1 or failed_validation_level_0: if failed_validation_level_0: @@ -888,6 +900,88 @@ def _run_cpu_aiu_validation_test( ) +def _get_cache_test_params(): + # NOTE - currently we always use granite 3.3 for the cache test, + # TODO make this configurable as tests are refactored + model_path = GRANITE_3p3_8B_INSTRUCT + batch_size = COMMON_BATCH_SIZES[0] + seq_length = COMMON_SEQ_LENGTHS[0] + max_new_tokens = COMMON_MAX_NEW_TOKENS[0] + return [model_path, batch_size, seq_length, max_new_tokens] + + +def _reset_cache_settings(purge_cache_dir): + os.environ["TORCH_SENDNN_CACHE_ENABLE"] = "1" + os.environ["COMPILATION_MODE"] = "offline_decoder" + cache_dir = os.environ["TORCH_SENDNN_CACHE_DIR"] + + # Ensure we start in clean state + if purge_cache_dir and os.path.isdir(cache_dir): + shutil.rmtree(cache_dir) + os.mkdir(cache_dir) + + _get_global_state().use_aiu_cache = True + _get_global_state().spyre_graph_cache = SpyreGraphCache() + + +@pytest.fixture +def use_cached_model(persistent_model, record_property): + """Configures the torchsendnn cache and runs the AIU model prior to test execution; + this is computationally expensive and should only be used in situations like testing + cache hit correctness; + """ + torch.manual_seed(42) + torch.set_grad_enabled(False) 
+ _reset_cache_settings(purge_cache_dir=True) + + model_path, batch_size, seq_length, max_new_tokens = _get_cache_test_params() + micro_model_path = MICRO_MODEL_MAPPING.get(model_path, None) + + def verify_cache_miss(): + cache_dir = os.environ.get("TORCH_SENDNN_CACHE_DIR") + updated_cache_len = ( + len(os.listdir(cache_dir)) if os.path.isdir(cache_dir) else 0 + ) + assert updated_cache_len == max_new_tokens, ( + "cache directory not populated on cache miss" + ) + + dprint( + f"Setting up cache [i.e., cache miss check] for model={model_path}, batch_size={batch_size}, seq_length={seq_length}, max_new_tokens={max_new_tokens}, micro_model={USE_MICRO_MODELS}" + ) + + # we don't currently support inferring gptq from get_model, so we must use an adapter with hf_configured + gptq_kwargs_aiu, gptq_kwargs_cpu = __maybe_get_gptq_kwargs(model_path) + is_gptq = len(gptq_kwargs_aiu) != 0 + is_fp8 = "fp8" in ATTN_NAME + model_kwargs = _get_common_model_kwargs(is_gptq, model_path) + + # Get the AIU model w/ the persistent model fixture + model = persistent_model.get_or_create( + is_gptq, is_fp8, **gptq_kwargs_aiu, **model_kwargs + ) + + validation_model = _get_cpu_model( + is_gptq, + is_fp8, + micro_model_state_dict=model.state_dict() if USE_MICRO_MODELS else None, + **gptq_kwargs_cpu, + **model_kwargs, + ) + + _run_cpu_aiu_validation_test( + model_path, + batch_size, + seq_length, + max_new_tokens, + validation_model, + model, + micro_model_path, + record_property, + verify_cache_state=verify_cache_miss, + ) + + @pytest.mark.parametrize( "model_path,batch_size,seq_length,max_new_tokens", COMMON_SHAPES ) @@ -937,3 +1031,56 @@ def test_common_shapes( micro_model_path, record_property, ) + + +def test_cache(use_cached_model, persistent_model, record_property): + torch.manual_seed(42) + torch.set_grad_enabled(False) + _reset_cache_settings(purge_cache_dir=False) + + model_path, batch_size, seq_length, max_new_tokens = _get_cache_test_params() + micro_model_path = MICRO_MODEL_MAPPING.get(model_path, None) + + def verify_cache_hit(): + cache_dir = os.environ.get("TORCH_SENDNN_CACHE_DIR") + updated_cache_len = ( + len(os.listdir(cache_dir)) if os.path.isdir(cache_dir) else 0 + ) + assert updated_cache_len == max_new_tokens, ( + "cache miss occurred when hit was expected" + ) + + dprint( + f"testing: model={model_path}, batch_size={batch_size}, seq_length={seq_length}, max_new_tokens={max_new_tokens}, micro_model={USE_MICRO_MODELS}, for cache hit" + ) + + # we don't currently support inferring gptq from get_model, so we must use an adapter with hf_configured + gptq_kwargs_aiu, gptq_kwargs_cpu = __maybe_get_gptq_kwargs(model_path) + is_gptq = len(gptq_kwargs_aiu) != 0 + is_fp8 = "fp8" in ATTN_NAME + model_kwargs = _get_common_model_kwargs(is_gptq, model_path) + + # Get the AIU model w/ the persistent model fixture + model = persistent_model.get_or_create( + is_gptq, is_fp8, **gptq_kwargs_aiu, **model_kwargs + ) + + validation_model = _get_cpu_model( + is_gptq, + is_fp8, + micro_model_state_dict=model.state_dict() if USE_MICRO_MODELS else None, + **gptq_kwargs_cpu, + **model_kwargs, + ) + + _run_cpu_aiu_validation_test( + model_path, + batch_size, + seq_length, + max_new_tokens, + validation_model, + model, + micro_model_path, + record_property, + verify_cache_state=verify_cache_hit, + ) From 42305bb7534736bbbf64021bc54ca453b3a6c425 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Mon, 13 Oct 2025 08:41:34 +0000 Subject: [PATCH 2/8] use tmp_path fixture for cache test Signed-off-by: Alex-Brooks --- 
tests/models/test_decoders.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/tests/models/test_decoders.py b/tests/models/test_decoders.py index 5aa64aa5..575e5317 100644 --- a/tests/models/test_decoders.py +++ b/tests/models/test_decoders.py @@ -910,29 +910,35 @@ def _get_cache_test_params(): return [model_path, batch_size, seq_length, max_new_tokens] -def _reset_cache_settings(purge_cache_dir): +def _reset_cache_settings(purge_cache_dir, cache_dir=None): os.environ["TORCH_SENDNN_CACHE_ENABLE"] = "1" os.environ["COMPILATION_MODE"] = "offline_decoder" - cache_dir = os.environ["TORCH_SENDNN_CACHE_DIR"] + if cache_dir is not None: + # Might be a posixpath + cache_dir = str(cache_dir) + os.environ["TORCH_SENDNN_CACHE_DIR"] = cache_dir # Ensure we start in clean state if purge_cache_dir and os.path.isdir(cache_dir): shutil.rmtree(cache_dir) os.mkdir(cache_dir) + # NOTE: currently, the cache dir is pulled from + # TORCH_SENDNN_CACHE_DIR at initialization time, + # so this should correctly use the cache_dir _get_global_state().use_aiu_cache = True _get_global_state().spyre_graph_cache = SpyreGraphCache() @pytest.fixture -def use_cached_model(persistent_model, record_property): +def use_cached_model(persistent_model, record_property, tmp_path): """Configures the torchsendnn cache and runs the AIU model prior to test execution; this is computationally expensive and should only be used in situations like testing cache hit correctness; """ torch.manual_seed(42) torch.set_grad_enabled(False) - _reset_cache_settings(purge_cache_dir=True) + _reset_cache_settings(purge_cache_dir=True, cache_dir=tmp_path) model_path, batch_size, seq_length, max_new_tokens = _get_cache_test_params() micro_model_path = MICRO_MODEL_MAPPING.get(model_path, None) @@ -1033,10 +1039,10 @@ def test_common_shapes( ) -def test_cache(use_cached_model, persistent_model, record_property): +def test_cache(use_cached_model, persistent_model, record_property, tmp_path): torch.manual_seed(42) torch.set_grad_enabled(False) - _reset_cache_settings(purge_cache_dir=False) + _reset_cache_settings(purge_cache_dir=False, cache_dir=tmp_path) model_path, batch_size, seq_length, max_new_tokens = _get_cache_test_params() micro_model_path = MICRO_MODEL_MAPPING.get(model_path, None) From fe8c61f5d865a925b768091ebe874add45a645ac Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Mon, 13 Oct 2025 08:48:35 +0000 Subject: [PATCH 3/8] fix cache_dir in cache checks Signed-off-by: Alex-Brooks --- tests/models/test_decoders.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/models/test_decoders.py b/tests/models/test_decoders.py index 575e5317..1a5b5f31 100644 --- a/tests/models/test_decoders.py +++ b/tests/models/test_decoders.py @@ -944,7 +944,7 @@ def use_cached_model(persistent_model, record_property, tmp_path): micro_model_path = MICRO_MODEL_MAPPING.get(model_path, None) def verify_cache_miss(): - cache_dir = os.environ.get("TORCH_SENDNN_CACHE_DIR") + cache_dir = str(tmp_path) updated_cache_len = ( len(os.listdir(cache_dir)) if os.path.isdir(cache_dir) else 0 ) @@ -1048,7 +1048,7 @@ def test_cache(use_cached_model, persistent_model, record_property, tmp_path): micro_model_path = MICRO_MODEL_MAPPING.get(model_path, None) def verify_cache_hit(): - cache_dir = os.environ.get("TORCH_SENDNN_CACHE_DIR") + cache_dir = str(tmp_path) updated_cache_len = ( len(os.listdir(cache_dir)) if os.path.isdir(cache_dir) else 0 ) From 630fe39abc500e72dfcb9b1afb8065e802cb0ffa Mon Sep 17 00:00:00 2001 From: 
Alex-Brooks Date: Mon, 13 Oct 2025 08:59:06 +0000 Subject: [PATCH 4/8] only warmup on cache tests Signed-off-by: Alex-Brooks --- tests/models/test_decoders.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/tests/models/test_decoders.py b/tests/models/test_decoders.py index 1a5b5f31..5075b888 100644 --- a/tests/models/test_decoders.py +++ b/tests/models/test_decoders.py @@ -836,6 +836,7 @@ def _run_cpu_aiu_validation_test( micro_model_path, record_property, verify_cache_state=None, + warmup_only=False, ): # Get the tokenizer and AIU / CPU models to compare tokenizer = AutoTokenizer.from_pretrained(model_path) @@ -859,6 +860,16 @@ def _run_cpu_aiu_validation_test( aiu_model, input_ids, max_new_tokens, COMPILE_DYNAMIC_SENDNN, **extra_kwargs ) + # Used only for cache tests; this is a nonparametric closure that + # should assert the cache for torch sendnn is in the correct state + # for this test + if verify_cache_state is not None: + verify_cache_state() + + # For some tests, e.g., cache checks, we only need to run the warmup + if warmup_only: + return + # Run validation level 0 failed_validation_level_0, validation_zero_info = _run_validation_level_0( model_path, @@ -872,12 +883,6 @@ def _run_cpu_aiu_validation_test( aiu_model, ) - # Used only for cache tests; this is a nonparametric closure that - # should assert the cache for torch sendnn is in the correct state - # for this test - if verify_cache_state is not None: - verify_cache_state() - # if level 0 fails validation, validate level 1 if FORCE_VALIDATION_LEVEL_1 or failed_validation_level_0: if failed_validation_level_0: From fd1c20f581b9b0255b573884f31c37be8c1001ea Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Mon, 13 Oct 2025 09:11:15 +0000 Subject: [PATCH 5/8] parametrize use_cache Signed-off-by: Alex-Brooks --- tests/models/test_decoders.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/tests/models/test_decoders.py b/tests/models/test_decoders.py index 5075b888..27178aa3 100644 --- a/tests/models/test_decoders.py +++ b/tests/models/test_decoders.py @@ -936,7 +936,7 @@ def _reset_cache_settings(purge_cache_dir, cache_dir=None): @pytest.fixture -def use_cached_model(persistent_model, record_property, tmp_path): +def use_cached_model(request, persistent_model, record_property, tmp_path): """Configures the torchsendnn cache and runs the AIU model prior to test execution; this is computationally expensive and should only be used in situations like testing cache hit correctness; @@ -990,7 +990,9 @@ def verify_cache_miss(): micro_model_path, record_property, verify_cache_state=verify_cache_miss, + warmup_only=True, ) + return request.param @pytest.mark.parametrize( @@ -1044,12 +1046,19 @@ def test_common_shapes( ) +@pytest.mark.parametrize( + "use_cached_model", + COMMON_SHAPES, + indirect=True, +) def test_cache(use_cached_model, persistent_model, record_property, tmp_path): torch.manual_seed(42) torch.set_grad_enabled(False) _reset_cache_settings(purge_cache_dir=False, cache_dir=tmp_path) - model_path, batch_size, seq_length, max_new_tokens = _get_cache_test_params() + # use_cached_model is an indirectly parametrized fixture, and the returned + # value is an expanded tuple from COMMON_SHAPES, so we unpack it here + model_path, batch_size, seq_length, max_new_tokens = use_cached_model micro_model_path = MICRO_MODEL_MAPPING.get(model_path, None) def verify_cache_hit(): @@ -1094,4 +1103,5 @@ def verify_cache_hit(): micro_model_path, record_property, 
verify_cache_state=verify_cache_hit, + warmup_only=True, ) From 15665834a038a5633fafdf134e4c0f0b8aa3efd8 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Mon, 13 Oct 2025 09:33:24 +0000 Subject: [PATCH 6/8] use request param for setting up use_cache Signed-off-by: Alex-Brooks --- tests/models/test_decoders.py | 18 ++++-------------- 1 file changed, 4 insertions(+), 14 deletions(-) diff --git a/tests/models/test_decoders.py b/tests/models/test_decoders.py index 27178aa3..ad90dcdb 100644 --- a/tests/models/test_decoders.py +++ b/tests/models/test_decoders.py @@ -905,16 +905,6 @@ def _run_cpu_aiu_validation_test( ) -def _get_cache_test_params(): - # NOTE - currently we always use granite 3.3 for the cache test, - # TODO make this configurable as tests are refactored - model_path = GRANITE_3p3_8B_INSTRUCT - batch_size = COMMON_BATCH_SIZES[0] - seq_length = COMMON_SEQ_LENGTHS[0] - max_new_tokens = COMMON_MAX_NEW_TOKENS[0] - return [model_path, batch_size, seq_length, max_new_tokens] - - def _reset_cache_settings(purge_cache_dir, cache_dir=None): os.environ["TORCH_SENDNN_CACHE_ENABLE"] = "1" os.environ["COMPILATION_MODE"] = "offline_decoder" @@ -937,15 +927,15 @@ def _reset_cache_settings(purge_cache_dir, cache_dir=None): @pytest.fixture def use_cached_model(request, persistent_model, record_property, tmp_path): - """Configures the torchsendnn cache and runs the AIU model prior to test execution; - this is computationally expensive and should only be used in situations like testing - cache hit correctness; + """Configures the torchsendnn cache and runs the AIU model (warmup) + prior to test execution; this is computationally expensive and should + only be used in situations like testing cache hit correctness. """ torch.manual_seed(42) torch.set_grad_enabled(False) _reset_cache_settings(purge_cache_dir=True, cache_dir=tmp_path) - model_path, batch_size, seq_length, max_new_tokens = _get_cache_test_params() + model_path, batch_size, seq_length, max_new_tokens = request.param micro_model_path = MICRO_MODEL_MAPPING.get(model_path, None) def verify_cache_miss(): From c33072e696bc15f8804b2c1581da849fa85bbce6 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Sun, 26 Oct 2025 11:57:24 +0000 Subject: [PATCH 7/8] reuse aiu/cpu models from cache miss fixture Signed-off-by: Alex-Brooks --- tests/models/test_decoders.py | 34 +++++++++++----------------------- 1 file changed, 11 insertions(+), 23 deletions(-) diff --git a/tests/models/test_decoders.py b/tests/models/test_decoders.py index ad90dcdb..a336836d 100644 --- a/tests/models/test_decoders.py +++ b/tests/models/test_decoders.py @@ -969,6 +969,8 @@ def verify_cache_miss(): **gptq_kwargs_cpu, **model_kwargs, ) + # We also return the models so that we can reuse them in the cache hit check + models = (model, validation_model) _run_cpu_aiu_validation_test( model_path, @@ -982,7 +984,7 @@ def verify_cache_miss(): verify_cache_state=verify_cache_miss, warmup_only=True, ) - return request.param + return request.param, models @pytest.mark.parametrize( @@ -1041,14 +1043,19 @@ def test_common_shapes( COMMON_SHAPES, indirect=True, ) -def test_cache(use_cached_model, persistent_model, record_property, tmp_path): +def test_cache(use_cached_model, record_property, tmp_path): torch.manual_seed(42) torch.set_grad_enabled(False) _reset_cache_settings(purge_cache_dir=False, cache_dir=tmp_path) # use_cached_model is an indirectly parametrized fixture, and the returned - # value is an expanded tuple from COMMON_SHAPES, so we unpack it here - model_path, batch_size, 
seq_length, max_new_tokens = use_cached_model + # value is an expanded tuple from COMMON_SHAPES, so we unpack it here. + # In addition, we also pass the model created on AIU in the fixture to + # avoid recreating it. + test_params, models = use_cached_model + model, validation_model = models + model_path, batch_size, seq_length, max_new_tokens = test_params + micro_model_path = MICRO_MODEL_MAPPING.get(model_path, None) def verify_cache_hit(): @@ -1064,25 +1071,6 @@ def verify_cache_hit(): f"testing: model={model_path}, batch_size={batch_size}, seq_length={seq_length}, max_new_tokens={max_new_tokens}, micro_model={USE_MICRO_MODELS}, for cache hit" ) - # we don't currently support inferring gptq from get_model, so we must use an adapter with hf_configured - gptq_kwargs_aiu, gptq_kwargs_cpu = __maybe_get_gptq_kwargs(model_path) - is_gptq = len(gptq_kwargs_aiu) != 0 - is_fp8 = "fp8" in ATTN_NAME - model_kwargs = _get_common_model_kwargs(is_gptq, model_path) - - # Get the AIU model w/ the persistent model fixture - model = persistent_model.get_or_create( - is_gptq, is_fp8, **gptq_kwargs_aiu, **model_kwargs - ) - - validation_model = _get_cpu_model( - is_gptq, - is_fp8, - micro_model_state_dict=model.state_dict() if USE_MICRO_MODELS else None, - **gptq_kwargs_cpu, - **model_kwargs, - ) - _run_cpu_aiu_validation_test( model_path, batch_size, From b8181ac08f3ebe3a60beef0466bf68faf5e4571b Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Sun, 2 Nov 2025 04:09:25 +0000 Subject: [PATCH 8/8] remove duplicate code Signed-off-by: Alex-Brooks --- tests/models/test_decoders.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/models/test_decoders.py b/tests/models/test_decoders.py index a336836d..dd20cfa7 100644 --- a/tests/models/test_decoders.py +++ b/tests/models/test_decoders.py @@ -599,9 +599,6 @@ def _get_device_validation_information( if cpu_validation_info is not None: return cpu_validation_info - if cpu_validation_info is not None: - return cpu_validation_info - # overrides for validation info that are device specific device_dependent_kwargs = {} if device == "cpu":
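
For reference, the fixture/test wiring that patches 2 through 7 converge on can be reduced to a small, self-contained pytest sketch. The heavy pieces are replaced with hypothetical stand-ins (fake_compile plays the role of the AIU warmup, SHAPES the role of COMMON_SHAPES; neither is part of aiu-fms-testing-utils or torch_sendnn); only the structure is taken from the patches above: the fixture receives its shape tuple through request.param via indirect=True, warms the cache under the function-scoped tmp_path, verifies the miss with a closure-style assert, and returns the tuple so the test can re-run the warmup against the same directory and verify the hit without declaring its own parametrize decorator.

    # cache_pattern_sketch.py -- illustrative only: fake_compile and SHAPES are
    # assumed stand-ins, not real APIs from this repository or torch_sendnn.
    import os

    import pytest

    SHAPES = [("model-a", 1, 64, 8), ("model-b", 2, 128, 8)]


    def fake_compile(cache_dir, max_new_tokens):
        """Pretend warmup: write one artifact per decode step unless it already exists."""
        os.makedirs(cache_dir, exist_ok=True)
        for i in range(max_new_tokens):
            path = os.path.join(cache_dir, f"graph_{i}.bin")
            if not os.path.exists(path):  # a cache hit reuses the artifact instead
                with open(path, "wb") as f:
                    f.write(b"\0")


    @pytest.fixture
    def use_cached_model(request, tmp_path):
        """Warm the cache once (the miss) and hand the shape tuple to the test."""
        model_path, batch_size, seq_length, max_new_tokens = request.param
        cache_dir = str(tmp_path)

        fake_compile(cache_dir, max_new_tokens)
        # Mirrors verify_cache_miss: the warmup must have populated the directory.
        assert len(os.listdir(cache_dir)) == max_new_tokens, (
            "cache directory not populated on cache miss"
        )
        return request.param


    @pytest.mark.parametrize("use_cached_model", SHAPES, indirect=True)
    def test_cache(use_cached_model, tmp_path):
        # tmp_path is function-scoped, so the fixture and the test see the same dir.
        model_path, batch_size, seq_length, max_new_tokens = use_cached_model

        before = set(os.listdir(str(tmp_path)))
        fake_compile(str(tmp_path), max_new_tokens)
        after = set(os.listdir(str(tmp_path)))
        # Mirrors verify_cache_hit: the second warmup must reuse, not recreate.
        assert after == before, "cache miss occurred when hit was expected"

Returning request.param from the fixture is what lets test_cache recover its parameters without duplicating the parametrize decorator; patch 7 extends that same return value to also carry the already-built AIU and CPU models so the hit-side run can skip get_or_create entirely.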