 )
 import json
 from aiu_fms_testing_utils.utils.aiu_setup import dprint, aiu_dist_setup
+import shutil
 import os

 try:
@@ -786,6 +787,7 @@ def _run_cpu_aiu_validation_test(
     cpu_model,
     aiu_model,
     micro_model_path,
+    verify_cache_state=None,
 ):
     # Get the tokenizer and AIU / CPU models to compare
     tokenizer = tokenizers.get_tokenizer(model_path)
@@ -811,6 +813,12 @@ def _run_cpu_aiu_validation_test(
         aiu_model,
     )

+    # Used only for cache tests; this is a zero-argument closure that
+    # should assert that the torch_sendnn cache is in the correct state
+    # for this test
+    if verify_cache_state is not None:
+        verify_cache_state()
+
     # if level 0 fails validation, validate level 1
     if FORCE_VALIDATION_LEVEL_1 or failed_validation_level_0:
         if failed_validation_level_0:
@@ -832,6 +840,87 @@ def _run_cpu_aiu_validation_test(
     )


+def _reset_cache_settings(purge_cache_dir):
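+    # Enable the torch_sendnn compile cache and (presumably) offline decoder
+    # compilation so compiled graphs are persisted under TORCH_SENDNN_CACHE_DIR.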
+    os.environ["TORCH_SENDNN_CACHE_ENABLE"] = "1"
+    os.environ["COMPILATION_MODE"] = "offline_decoder"
+    cache_dir = os.environ["TORCH_SENDNN_CACHE_DIR"]
+
+    # Ensure we start in a clean state
+    if purge_cache_dir and os.path.isdir(cache_dir):
+        shutil.rmtree(cache_dir)
+        os.mkdir(cache_dir)
+
+    from torch_sendnn.backends import cache
+
+    # Explicitly clear cache paths from the global torch_sendnn graph;
+    # TODO would be better to add a helper to explicitly do this in
+    # torch_sendnn
+    cache.cache = {}
+
+
+@pytest.fixture
+def use_cached_model():
+    """Configures the torch_sendnn cache and runs the AIU model prior to test
+    execution; this is computationally expensive and should only be used in
+    situations like testing cache hit correctness.
+    """
+    torch.manual_seed(42)
+    torch.set_grad_enabled(False)
+    _reset_cache_settings(purge_cache_dir=True)
+
+    model_path, batch_size, seq_length, max_new_tokens = _get_cache_test_params()
+    micro_model_path = MICRO_MODEL_MAPPING.get(model_path, None)
+
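+    # The assertion below assumes a populated cache directory ends up with one
+    # entry per generated token (max_new_tokens entries) after a cache miss.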
+    def verify_cache_miss():
+        cache_dir = os.environ.get("TORCH_SENDNN_CACHE_DIR")
+        updated_cache_len = (
+            len(os.listdir(cache_dir)) if os.path.isdir(cache_dir) else 0
+        )
+        assert updated_cache_len == max_new_tokens, (
+            "cache directory not populated on cache miss"
+        )
+
+    dprint(
+        f"Setting up cache [i.e., cache miss check] for model={model_path}, batch_size={batch_size}, seq_length={seq_length}, max_new_tokens={max_new_tokens}, micro_model={USE_MICRO_MODELS}"
+    )
+
+    # we don't currently support inferring gptq from get_model, so we must use an adapter with hf_configured
+    gptq_kwargs_aiu, gptq_kwargs_cpu = __maybe_get_gptq_kwargs(model_path)
+
+    model = _get_aiu_model(
+        model_path,
+        gptq_kwargs_aiu,
+        persistent_model_inst=None,
+    )
+
+    validation_model = _get_cpu_model(
+        model_path,
+        gptq_kwargs_cpu,
+        micro_model_state_dict=model.state_dict() if USE_MICRO_MODELS else None,
+    )
+
+    _run_cpu_aiu_validation_test(
+        model_path,
+        batch_size,
+        seq_length,
+        max_new_tokens,
+        validation_model,
+        model,
+        micro_model_path,
+        verify_cache_state=verify_cache_miss,
+    )
+
+
+def _get_cache_test_params():
+    # NOTE - currently we always use granite 3.3 for the cache test,
+    # TODO make this configurable as tests are refactored
+    model_path = GRANITE_3p3_8B_INSTRUCT
+    batch_size = COMMON_BATCH_SIZES[0]
+    seq_length = COMMON_SEQ_LENGTHS[0]
+    max_new_tokens = COMMON_MAX_NEW_TOKENS[0]
+    return [model_path, batch_size, seq_length, max_new_tokens]
+
+
 @pytest.mark.parametrize(
     "model_path,batch_size,seq_length,max_new_tokens", common_shapes
 )
@@ -870,3 +959,51 @@ def test_common_shapes(
         model,
         micro_model_path,
     )
+
+
+def test_cache(use_cached_model):
+    torch.manual_seed(42)
+    torch.set_grad_enabled(False)
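+    # The use_cached_model fixture has already compiled the model and populated
+    # the on-disk cache; keep that directory (purge_cache_dir=False) and clear
+    # only the in-memory graph cache so this run should hit the persisted cache.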
+    _reset_cache_settings(purge_cache_dir=False)
+
+    model_path, batch_size, seq_length, max_new_tokens = _get_cache_test_params()
+    micro_model_path = MICRO_MODEL_MAPPING.get(model_path, None)
+
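+    # The assertion below checks that the cache directory still holds the
+    # max_new_tokens entries written by the fixture.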
+    def verify_cache_hit():
+        cache_dir = os.environ.get("TORCH_SENDNN_CACHE_DIR")
+        updated_cache_len = (
+            len(os.listdir(cache_dir)) if os.path.isdir(cache_dir) else 0
+        )
+        assert updated_cache_len == max_new_tokens, (
+            "cache miss occurred when hit was expected"
+        )
+
+    dprint(
+        f"testing: model={model_path}, batch_size={batch_size}, seq_length={seq_length}, max_new_tokens={max_new_tokens}, micro_model={USE_MICRO_MODELS}, for cache hit"
+    )
+
+    # we don't currently support inferring gptq from get_model, so we must use an adapter with hf_configured
+    gptq_kwargs_aiu, gptq_kwargs_cpu = __maybe_get_gptq_kwargs(model_path)
+
+    model = _get_aiu_model(
+        model_path,
+        gptq_kwargs_aiu,
+        persistent_model_inst=None,
+    )
+
+    validation_model = _get_cpu_model(
+        model_path,
+        gptq_kwargs_cpu,
+        micro_model_state_dict=model.state_dict() if USE_MICRO_MODELS else None,
+    )
+
+    _run_cpu_aiu_validation_test(
+        model_path,
+        batch_size,
+        seq_length,
+        max_new_tokens,
+        validation_model,
+        model,
+        micro_model_path,
+        verify_cache_state=verify_cache_hit,
+    )