import torch
from torch import distributed as dist
from torch.fx.experimental import _config as fx_config
+ from torch_sendnn.backends.sendnn_backend import _get_global_state
+ from torch_sendnn.utils.graph_cache import SpyreGraphCache

from aiu_fms_testing_utils.testing.validation import (
    extract_validation_information,
    ...
from transformers import AutoTokenizer

from aiu_fms_testing_utils.utils.aiu_setup import dprint, aiu_dist_setup
+ import shutil
import os

try:
    ...

if USE_MICRO_MODELS:
    VALIDATION_INFO_DIR = os.path.join(VALIDATION_INFO_DIR, "tiny_models")
- # pass custom model path list for eg: EXPORT FMS_TEST_SHAPES_COMMON_MODEL_PATHS="/tmp/models/granite-3-8b-base,/tmp/models/granite-7b-base"
+ # pass a custom model path list, e.g.: export FMS_TESTING_COMMON_MODEL_PATHS="/tmp/models/granite-3-8b-base,/tmp/models/granite-7b-base"
if isinstance(COMMON_MODEL_PATHS, str):
    COMMON_MODEL_PATHS = COMMON_MODEL_PATHS.split(",")
@@ -593,6 +596,8 @@ def _get_device_validation_information(
        token_iter,
        ATTN_NAME,
    )
+     if cpu_validation_info is not None:
+         return cpu_validation_info

    if cpu_validation_info is not None:
        return cpu_validation_info
@@ -830,6 +835,7 @@ def _run_cpu_aiu_validation_test(
    aiu_model,
    micro_model_path,
    record_property,
+     verify_cache_state=None,
):
    # Get the tokenizer and AIU / CPU models to compare
    tokenizer = AutoTokenizer.from_pretrained(model_path)
@@ -866,6 +872,12 @@ def _run_cpu_aiu_validation_test(
        aiu_model,
    )

+     # Used only for the cache tests; this is a zero-argument closure that
+     # should assert that the torch_sendnn cache is in the correct state
+     # for this test
+     if verify_cache_state is not None:
+         verify_cache_state()
+
    # if level 0 fails validation, validate level 1
    if FORCE_VALIDATION_LEVEL_1 or failed_validation_level_0:
        if failed_validation_level_0:
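The hook above is a zero-argument callable invoked once, between the level-0 and level-1 validation steps, purely for its assertions. A minimal sketch of the intended contract (the callback body here is illustrative; the real callbacks are the `verify_cache_miss` and `verify_cache_hit` closures defined below):

    def _example_cache_state_check():
        # hypothetical invariant: the cache directory exists after the AIU run
        cache_dir = os.environ.get("TORCH_SENDNN_CACHE_DIR")
        assert cache_dir is not None and os.path.isdir(cache_dir), "cache dir missing"

    # passed through as a keyword and only called when not None:
    #   _run_cpu_aiu_validation_test(..., verify_cache_state=_example_cache_state_check)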
@@ -888,6 +900,88 @@ def _run_cpu_aiu_validation_test(
        )


+ def _get_cache_test_params():
+     # NOTE: currently we always use granite 3.3 for the cache test;
+     # TODO: make this configurable as tests are refactored
+     model_path = GRANITE_3p3_8B_INSTRUCT
+     batch_size = COMMON_BATCH_SIZES[0]
+     seq_length = COMMON_SEQ_LENGTHS[0]
+     max_new_tokens = COMMON_MAX_NEW_TOKENS[0]
+     return [model_path, batch_size, seq_length, max_new_tokens]
+
+
+ def _reset_cache_settings(purge_cache_dir):
+     os.environ["TORCH_SENDNN_CACHE_ENABLE"] = "1"
+     os.environ["COMPILATION_MODE"] = "offline_decoder"
+     cache_dir = os.environ["TORCH_SENDNN_CACHE_DIR"]
+
+     # Ensure we start in a clean state
+     if purge_cache_dir and os.path.isdir(cache_dir):
+         shutil.rmtree(cache_dir)
+         os.mkdir(cache_dir)
+
+     _get_global_state().use_aiu_cache = True
+     _get_global_state().spyre_graph_cache = SpyreGraphCache()
+
+
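One sharp edge worth noting: `_reset_cache_settings` indexes `TORCH_SENDNN_CACHE_DIR` directly, so it raises a `KeyError` when the variable is unset, and it only recreates the directory on the purge path. A sketch of how a harness might satisfy that precondition (the path and placement are assumptions, not part of this change):

    # e.g. in CI setup or a conftest, before the cache tests run
    import os
    os.environ.setdefault("TORCH_SENDNN_CACHE_DIR", "/tmp/torch_sendnn_cache")  # hypothetical path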
+ @pytest.fixture
+ def use_cached_model(persistent_model, record_property):
+     """Configures the torch_sendnn cache and runs the AIU model prior to test
+     execution; this is computationally expensive and should only be used in
+     situations such as testing cache-hit correctness.
+     """
+     torch.manual_seed(42)
+     torch.set_grad_enabled(False)
+     _reset_cache_settings(purge_cache_dir=True)
+
+     model_path, batch_size, seq_length, max_new_tokens = _get_cache_test_params()
+     micro_model_path = MICRO_MODEL_MAPPING.get(model_path, None)
+
+     def verify_cache_miss():
+         cache_dir = os.environ.get("TORCH_SENDNN_CACHE_DIR")
+         updated_cache_len = (
+             len(os.listdir(cache_dir)) if os.path.isdir(cache_dir) else 0
+         )
+         assert updated_cache_len == max_new_tokens, (
+             "cache directory not populated on cache miss"
+         )
+
+     dprint(
+         f"Setting up cache [i.e., cache miss check] for model={model_path}, batch_size={batch_size}, seq_length={seq_length}, max_new_tokens={max_new_tokens}, micro_model={USE_MICRO_MODELS}"
+     )
+
+     # we don't currently support inferring gptq from get_model, so we must use an adapter with hf_configured
+     gptq_kwargs_aiu, gptq_kwargs_cpu = __maybe_get_gptq_kwargs(model_path)
+     is_gptq = len(gptq_kwargs_aiu) != 0
+     is_fp8 = "fp8" in ATTN_NAME
+     model_kwargs = _get_common_model_kwargs(is_gptq, model_path)
+
+     # Get the AIU model w/ the persistent model fixture
+     model = persistent_model.get_or_create(
+         is_gptq, is_fp8, **gptq_kwargs_aiu, **model_kwargs
+     )
+
+     validation_model = _get_cpu_model(
+         is_gptq,
+         is_fp8,
+         micro_model_state_dict=model.state_dict() if USE_MICRO_MODELS else None,
+         **gptq_kwargs_cpu,
+         **model_kwargs,
+     )
+
+     _run_cpu_aiu_validation_test(
+         model_path,
+         batch_size,
+         seq_length,
+         max_new_tokens,
+         validation_model,
+         model,
+         micro_model_path,
+         record_property,
+         verify_cache_state=verify_cache_miss,
+     )
+
+
@pytest.mark.parametrize(
    "model_path,batch_size,seq_length,max_new_tokens", COMMON_SHAPES
)
@@ -937,3 +1031,56 @@ def test_common_shapes(
        micro_model_path,
        record_property,
    )
+
+
+ def test_cache(use_cached_model, persistent_model, record_property):
+     torch.manual_seed(42)
+     torch.set_grad_enabled(False)
+     _reset_cache_settings(purge_cache_dir=False)
+
+     model_path, batch_size, seq_length, max_new_tokens = _get_cache_test_params()
+     micro_model_path = MICRO_MODEL_MAPPING.get(model_path, None)
+
+     def verify_cache_hit():
+         cache_dir = os.environ.get("TORCH_SENDNN_CACHE_DIR")
+         updated_cache_len = (
+             len(os.listdir(cache_dir)) if os.path.isdir(cache_dir) else 0
+         )
+         assert updated_cache_len == max_new_tokens, (
+             "cache miss occurred when a hit was expected"
+         )
+
+     dprint(
+         f"Testing cache hit for model={model_path}, batch_size={batch_size}, seq_length={seq_length}, max_new_tokens={max_new_tokens}, micro_model={USE_MICRO_MODELS}"
+     )
+
+     # we don't currently support inferring gptq from get_model, so we must use an adapter with hf_configured
+     gptq_kwargs_aiu, gptq_kwargs_cpu = __maybe_get_gptq_kwargs(model_path)
+     is_gptq = len(gptq_kwargs_aiu) != 0
+     is_fp8 = "fp8" in ATTN_NAME
+     model_kwargs = _get_common_model_kwargs(is_gptq, model_path)
+
+     # Get the AIU model w/ the persistent model fixture
+     model = persistent_model.get_or_create(
+         is_gptq, is_fp8, **gptq_kwargs_aiu, **model_kwargs
+     )
+
+     validation_model = _get_cpu_model(
+         is_gptq,
+         is_fp8,
+         micro_model_state_dict=model.state_dict() if USE_MICRO_MODELS else None,
+         **gptq_kwargs_cpu,
+         **model_kwargs,
+     )
+
+     _run_cpu_aiu_validation_test(
+         model_path,
+         batch_size,
+         seq_length,
+         max_new_tokens,
+         validation_model,
+         model,
+         micro_model_path,
+         record_property,
+         verify_cache_state=verify_cache_hit,
+     )
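Taken together, the fixture and the test exercise both sides of the torch_sendnn cache: `use_cached_model` purges the cache directory, runs the model once, and asserts the directory now holds `max_new_tokens` entries (the miss path that populates the cache), while `test_cache` reruns the same shape with `purge_cache_dir=False` and asserts the entry count is unchanged at `max_new_tokens` (the hit path). A rough invocation sketch, assuming the test module path and cache directory shown here (both are assumptions, not part of this change):

    # TORCH_SENDNN_CACHE_DIR=/tmp/torch_sendnn_cache pytest -k test_cache tests/models/test_decoders.py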