Add Cache Miss/Hit Test #97
base: main
Changes from 6 commits
18bbf01
42305bb
fe8c61f
630fe39
fd1c20f
1566583
c33072e
b8181ac
```diff
@@ -7,6 +7,8 @@
 import torch
 from torch import distributed as dist
 from torch.fx.experimental import _config as fx_config
+from torch_sendnn.backends.sendnn_backend import _get_global_state
+from torch_sendnn.utils.graph_cache import SpyreGraphCache

 from aiu_fms_testing_utils.testing.validation import (
     extract_validation_information,
```
```diff
@@ -29,6 +31,7 @@
 from transformers import AutoTokenizer

 from aiu_fms_testing_utils.utils.aiu_setup import dprint, aiu_dist_setup
+import shutil
 import os

 try:
```
```diff
@@ -132,7 +135,7 @@
 if USE_MICRO_MODELS:
     VALIDATION_INFO_DIR = os.path.join(VALIDATION_INFO_DIR, "tiny_models")

-# pass custom model path list for eg: EXPORT FMS_TEST_SHAPES_COMMON_MODEL_PATHS="/tmp/models/granite-3-8b-base,/tmp/models/granite-7b-base"
+# pass custom model path list for eg: EXPORT FMS_TESTING_COMMON_MODEL_PATHS="/tmp/models/granite-3-8b-base,/tmp/models/granite-7b-base"
 if isinstance(COMMON_MODEL_PATHS, str):
     COMMON_MODEL_PATHS = COMMON_MODEL_PATHS.split(",")

```
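As the updated comment notes, the model-path list can be supplied through an environment variable as a comma-separated string. A small illustration follows; the variable name is taken from the comment above and the lookup shown here is an assumption about how the test module consumes it, not part of this hunk:

```python
# Illustration only: how a comma-separated env var (per the comment above)
# becomes a list of model paths. The real lookup lives elsewhere in the module.
import os

COMMON_MODEL_PATHS = os.environ.get(
    "FMS_TESTING_COMMON_MODEL_PATHS",
    "/tmp/models/granite-3-8b-base,/tmp/models/granite-7b-base",
)
if isinstance(COMMON_MODEL_PATHS, str):
    COMMON_MODEL_PATHS = COMMON_MODEL_PATHS.split(",")
```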
```diff
@@ -593,6 +596,8 @@ def _get_device_validation_information(
         token_iter,
         ATTN_NAME,
     )
+    if cpu_validation_info is not None:
+        return cpu_validation_info

     if cpu_validation_info is not None:
         return cpu_validation_info
```
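The lines added here give the function an early exit as soon as CPU validation information is available, so the later work is skipped. A minimal sketch of that load-or-compute pattern, with hypothetical `load_fn`/`compute_fn` placeholders rather than helpers from this repo:

```python
# Hypothetical sketch of the early-return pattern added above.
def get_validation_info(load_fn, compute_fn):
    info = load_fn()        # e.g. previously saved CPU validation info
    if info is not None:
        return info         # early return: skip recomputation
    return compute_fn()     # otherwise compute it from scratch
```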
```diff
@@ -830,6 +835,8 @@ def _run_cpu_aiu_validation_test(
     aiu_model,
     micro_model_path,
     record_property,
+    verify_cache_state=None,
+    warmup_only=False,
 ):
     # Get the tokenizer and AIU / CPU models to compare
     tokenizer = AutoTokenizer.from_pretrained(model_path)
```
```diff
@@ -853,6 +860,16 @@ def _run_cpu_aiu_validation_test(
         aiu_model, input_ids, max_new_tokens, COMPILE_DYNAMIC_SENDNN, **extra_kwargs
     )

+    # Used only for cache tests; this is a nonparametric closure that
+    # should assert the cache for torch sendnn is in the correct state
+    # for this test
+    if verify_cache_state is not None:
+        verify_cache_state()
+
+    # For some tests, e.g., cache checks, we only need to run the warmup
+    if warmup_only:
+        return
+
     # Run validation level 0
     failed_validation_level_0, validation_zero_info = _run_validation_level_0(
         model_path,
```
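The new `verify_cache_state` argument is a zero-argument closure invoked right after the AIU warmup, and `warmup_only` lets cache tests stop before the validation levels run. A hedged, illustrative sketch of how a caller could build such a closure; the helper below is not part of the diff:

```python
import os

def make_cache_check(cache_dir, expected_entries):
    """Illustrative only: build a zero-argument closure asserting on the cache dir."""
    def verify_cache_state():
        entries = len(os.listdir(cache_dir)) if os.path.isdir(cache_dir) else 0
        assert entries == expected_entries, "unexpected torch_sendnn cache state"
    return verify_cache_state

# The closure would then be passed through the new keyword arguments, e.g.:
#   _run_cpu_aiu_validation_test(..., verify_cache_state=make_cache_check(d, n),
#                                warmup_only=True)
```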
```diff
@@ -888,6 +905,86 @@ def _run_cpu_aiu_validation_test(
     )

+
+def _reset_cache_settings(purge_cache_dir, cache_dir=None):
+    os.environ["TORCH_SENDNN_CACHE_ENABLE"] = "1"
+    os.environ["COMPILATION_MODE"] = "offline_decoder"
+    if cache_dir is not None:
+        # Might be a posixpath
+        cache_dir = str(cache_dir)
+        os.environ["TORCH_SENDNN_CACHE_DIR"] = cache_dir
+
+    # Ensure we start in clean state
+    if purge_cache_dir and os.path.isdir(cache_dir):
+        shutil.rmtree(cache_dir)
+        os.mkdir(cache_dir)
+
+    # NOTE: currently, the cache dir is pulled from
+    # TORCH_SENDNN_CACHE_DIR at initialization time,
+    # so this should correctly use the cache_dir
+    _get_global_state().use_aiu_cache = True
+    _get_global_state().spyre_graph_cache = SpyreGraphCache()
+
+
+@pytest.fixture
+def use_cached_model(request, persistent_model, record_property, tmp_path):
+    """Configures the torchsendnn cache and runs the AIU model (warmup)
+    prior to test execution; this is computationally expensive and should
+    only be used in situations like testing cache hit correctness.
+    """
+    torch.manual_seed(42)
+    torch.set_grad_enabled(False)
+    _reset_cache_settings(purge_cache_dir=True, cache_dir=tmp_path)
+
+    model_path, batch_size, seq_length, max_new_tokens = request.param
+    micro_model_path = MICRO_MODEL_MAPPING.get(model_path, None)
+
+    def verify_cache_miss():
+        cache_dir = str(tmp_path)
+        updated_cache_len = (
+            len(os.listdir(cache_dir)) if os.path.isdir(cache_dir) else 0
+        )
+        assert updated_cache_len == max_new_tokens, (
+            "cache directory not populated on cache miss"
+        )
+
+    dprint(
+        f"Setting up cache [i.e., cache miss check] for model={model_path}, batch_size={batch_size}, seq_length={seq_length}, max_new_tokens={max_new_tokens}, micro_model={USE_MICRO_MODELS}"
+    )
+
+    # we don't currently support inferring gptq from get_model, so we must use an adapter with hf_configured
+    gptq_kwargs_aiu, gptq_kwargs_cpu = __maybe_get_gptq_kwargs(model_path)
+    is_gptq = len(gptq_kwargs_aiu) != 0
+    is_fp8 = "fp8" in ATTN_NAME
+    model_kwargs = _get_common_model_kwargs(is_gptq, model_path)
+
+    # Get the AIU model w/ the persistent model fixture
+    model = persistent_model.get_or_create(
+        is_gptq, is_fp8, **gptq_kwargs_aiu, **model_kwargs
+    )
+
+    validation_model = _get_cpu_model(
+        is_gptq,
+        is_fp8,
+        micro_model_state_dict=model.state_dict() if USE_MICRO_MODELS else None,
+        **gptq_kwargs_cpu,
+        **model_kwargs,
+    )
+
+    _run_cpu_aiu_validation_test(
+        model_path,
+        batch_size,
+        seq_length,
+        max_new_tokens,
+        validation_model,
+        model,
+        micro_model_path,
+        record_property,
+        verify_cache_state=verify_cache_miss,
+        warmup_only=True,
+    )
+    return request.param
+
+
 @pytest.mark.parametrize(
     "model_path,batch_size,seq_length,max_new_tokens", COMMON_SHAPES
 )
```
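`test_cache` (added in the next hunk) consumes `COMMON_SHAPES` through this fixture via pytest's indirect parametrization, so the expensive warmup runs once per shape before the test body executes. A stripped-down, self-contained sketch of that mechanism, with made-up shapes and no model work:

```python
import pytest

# Made-up shapes standing in for COMMON_SHAPES.
SHAPES = [("model-a", 1, 64, 8), ("model-b", 4, 128, 16)]

@pytest.fixture
def warmed_up(request):
    # Stand-in for use_cached_model: do the expensive setup, then hand the
    # parameters back so the test can unpack them.
    model_path, batch_size, seq_length, max_new_tokens = request.param
    # ... expensive warmup would happen here ...
    return request.param

@pytest.mark.parametrize("warmed_up", SHAPES, indirect=True)
def test_consumes_warmup(warmed_up):
    model_path, batch_size, seq_length, max_new_tokens = warmed_up
    assert max_new_tokens > 0
```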
```diff
@@ -937,3 +1034,64 @@ def test_common_shapes(
         micro_model_path,
         record_property,
     )
+
+
+@pytest.mark.parametrize(
+    "use_cached_model",
+    COMMON_SHAPES,
+    indirect=True,
+)
+def test_cache(use_cached_model, persistent_model, record_property, tmp_path):
+    torch.manual_seed(42)
+    torch.set_grad_enabled(False)
+    _reset_cache_settings(purge_cache_dir=False, cache_dir=tmp_path)
+
+    # use_cached_model is an indirectly parametrized fixture, and the returned
+    # value is an expanded tuple from COMMON_SHAPES, so we unpack it here
+    model_path, batch_size, seq_length, max_new_tokens = use_cached_model
+    micro_model_path = MICRO_MODEL_MAPPING.get(model_path, None)
+
+    def verify_cache_hit():
+        cache_dir = str(tmp_path)
+        updated_cache_len = (
+            len(os.listdir(cache_dir)) if os.path.isdir(cache_dir) else 0
+        )
+        assert updated_cache_len == max_new_tokens, (
+            "cache miss occurred when hit was expected"
+        )
+
+    dprint(
+        f"testing: model={model_path}, batch_size={batch_size}, seq_length={seq_length}, max_new_tokens={max_new_tokens}, micro_model={USE_MICRO_MODELS}, for cache hit"
+    )
+
+    # we don't currently support inferring gptq from get_model, so we must use an adapter with hf_configured
+    gptq_kwargs_aiu, gptq_kwargs_cpu = __maybe_get_gptq_kwargs(model_path)
+    is_gptq = len(gptq_kwargs_aiu) != 0
+    is_fp8 = "fp8" in ATTN_NAME
+    model_kwargs = _get_common_model_kwargs(is_gptq, model_path)
+
+    # Get the AIU model w/ the persistent model fixture
+    model = persistent_model.get_or_create(
+        is_gptq, is_fp8, **gptq_kwargs_aiu, **model_kwargs
+    )
+
+    validation_model = _get_cpu_model(
+        is_gptq,
+        is_fp8,
+        micro_model_state_dict=model.state_dict() if USE_MICRO_MODELS else None,
+        **gptq_kwargs_cpu,
+        **model_kwargs,
+    )
+
+    _run_cpu_aiu_validation_test(
+        model_path,
+        batch_size,
+        seq_length,
+        max_new_tokens,
+        validation_model,
+        model,
+        micro_model_path,
+        record_property,
+        verify_cache_state=verify_cache_hit,
+        warmup_only=True,
+    )
```
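Taken together, the fixture and `test_cache` exercise a miss-then-hit lifecycle: the fixture purges the cache directory and warms up (populating it), then the test re-runs the warmup against the same directory without purging and expects the same number of entries. Both closures assume the torch_sendnn graph cache ends up with one entry per generated token. A self-contained toy model of that lifecycle; `populate_cache` is a hypothetical stand-in for the real compile/warmup:

```python
import os
import tempfile

def populate_cache(cache_dir, max_new_tokens):
    # Hypothetical stand-in for the warmup writing one cache entry per decode step.
    os.makedirs(cache_dir, exist_ok=True)
    for i in range(max_new_tokens):
        open(os.path.join(cache_dir, f"graph_{i}.bin"), "wb").close()

with tempfile.TemporaryDirectory() as cache_dir:
    max_new_tokens = 8
    populate_cache(cache_dir, max_new_tokens)            # fixture: cache-miss path
    assert len(os.listdir(cache_dir)) == max_new_tokens  # verify_cache_miss analogue
    populate_cache(cache_dir, max_new_tokens)            # test: cache-hit path
    assert len(os.listdir(cache_dir)) == max_new_tokens  # verify_cache_hit analogue
```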