24 | 24 | ) |
25 | 25 | import json |
26 | 26 | from aiu_fms_testing_utils.utils.aiu_setup import dprint, aiu_dist_setup |
27 | | - |
| 27 | +import shutil |
28 | 28 | import os |
29 | 29 |
30 | 30 | try: |
175 | 175 | ) |
176 | 176 | os.environ["VLLM_DT_MAX_BATCH_SIZE"] = str(max(max(common_batch_sizes), 2)) |
177 | 177 |
178 | | -cache_params = list(itertools.product([common_model_paths[0]], [common_batch_sizes[0]], [common_seq_lengths[0]], [common_max_new_tokens[0]], ["miss", "hit"])) |
179 | 178 |
180 | 179 | # thresholds are chosen based on 1024 tokens per sequence |
181 | 180 | # 1% error threshold rate between cpu fp32 and cuda fp16 |
@@ -676,56 +675,272 @@ def _metric_calculator(r: torch.Tensor, t: torch.Tensor): |
676 | 675 | else: |
677 | 676 | print("passed validation level 0") |
678 | 677 |
679 | | -@pytest.mark.parametrize("model_path,batch_size,seq_length,max_new_tokens,cache_status", cache_params) |
680 | | -def test_cache(model_path, batch_size, seq_length, max_new_tokens, cache_status): |
| 678 | +@pytest.mark.parametrize("cache_status", ["miss", "hit"]) |
| 679 | +def test_cache(cache_status): |
681 | 680 | torch.manual_seed(42) |
| 681 | + torch.set_grad_enabled(False) |
682 | 682 | os.environ["TORCH_SENDNN_CACHE_ENABLE"] = "1" |
| 683 | + os.environ["TORCH_SENDNN_CACHE_DIR"] = os.getcwd()+"/.cache" |
683 | 684 | os.environ["COMPILATION_MODE"] = "offline_decoder" |
684 | 685 |
| 686 | + if cache_status == "miss" and os.path.isdir(os.getcwd()+"/.cache"): |
| 687 | + # Remove cache from previous runs |
| 688 | + shutil.rmtree(os.getcwd()+"/.cache") |
| 689 | + |
| 690 | + model_path = "ibm-granite/granite-3.3-8b-instruct" |
| 691 | + batch_size = common_batch_sizes[0] |
| 692 | + seq_length = common_seq_lengths[0] |
| 693 | + max_new_tokens = common_max_new_tokens[0] |
| 694 | + |
685 | 695 | dprint(f"testing with cache: model={model_path}, batch_size={batch_size}, seq_length={seq_length}, max_new_tokens={max_new_tokens}, micro_model={USE_MICRO_MODELS}, cache={cache_status}") |
686 | 696 |
687 | | - if USE_MICRO_MODELS: |
| 697 | + # we don't currently support inferring gptq from get_model, so we must use an adapter with hf_configured |
| 698 | + gptq_kwargs_aiu, gptq_kwargs_cpu = __maybe_get_gptq_kwargs(model_path) |
| 699 | + is_gptq = len(gptq_kwargs_aiu) != 0 |
| 700 | + |
| 701 | + micro_model_path = micro_model_mapping.get(model_path, None) |
| 702 | + if USE_MICRO_MODELS and micro_model_path is None: |
| 703 | + dprint("using randomly initialized model") |
688 | 704 | micro_model_kwargs = {"architecture": "hf_configured", "nlayers": 3} |
689 | 705 | else: |
690 | | - micro_model_kwargs = {"architecture": "hf_pretrained"} |
691 | | - |
| 706 | + dprint("using trained model") |
| 707 | + micro_model_kwargs = {"architecture": "hf_pretrained"} |
| 708 | + |
692 | 709 | if not USE_MICRO_MODELS and os.path.exists(model_path): |
693 | 710 | model_path_kwargs = {"model_path": model_path} |
| 711 | + elif USE_MICRO_MODELS and micro_model_path is not None: |
| 712 | + model_path_kwargs = {"model_path": micro_model_path} |
694 | 713 | else: |
695 | 714 | model_path_kwargs = {"variant": model_path} |
696 | | - |
| 715 | + |
697 | 716 | distributed_kwargs = {} |
698 | 717 | if USE_DISTRIBUTED: |
699 | | - distributed_kwargs["distr_param"] = "tp" |
| 718 | + distributed_kwargs["distributed_strategy"] = "tp" |
700 | 719 | distributed_kwargs["group"] = dist.group.WORLD |
701 | | - get_model_kwargs = {**model_path_kwargs, **micro_model_kwargs, **distributed_kwargs} |
| 720 | + |
| 721 | + get_model_kwargs = {} |
| 722 | + if not is_gptq: |
| 723 | + get_model_kwargs = { |
| 724 | + **model_path_kwargs, |
| 725 | + **micro_model_kwargs, |
| 726 | + **distributed_kwargs, |
| 727 | + } |
702 | 728 |
703 | 729 | tokenizer = tokenizers.get_tokenizer(model_path) |
704 | 730 |
705 | 731 | # prepare the AIU model |
706 | 732 | model = get_model( |
| 733 | + device_type="cpu", |
| 734 | + data_type=None if is_gptq else torch.float16, |
| 735 | + fused_weights=False, |
| 736 | + **get_model_kwargs, |
| 737 | + ) |
| 738 | + |
| 739 | + model.eval() |
| 740 | + model.compile(backend="sendnn") |
| 741 | + |
| 742 | + # prepare the cpu model |
| 743 | + validation_model = get_model( |
707 | 744 | device_type="cpu", |
| 745 | + data_type=None if is_gptq else torch.float32, |
708 | 746 | fused_weights=False, |
709 | | - **get_model_kwargs |
| 747 | + **gptq_kwargs_cpu, |
| 748 | + **get_model_kwargs, |
710 | 749 | ) |
711 | 750 |
712 | | - model.eval() |
713 | | - torch.set_grad_enabled(False) |
714 | | - model.compile(backend="sendnn_decoder") |
715 | | - |
| 751 | + if USE_MICRO_MODELS: |
| 752 | + serialization.load_state_dict_into_model( |
| 753 | + validation_model, model.state_dict(), **__custom_adapter |
| 754 | + ) |
716 | 755 |
717 | 756 | # prepare input_ids |
718 | | - input_ids, padding_kwargs = __prepare_inputs(batch_size, seq_length, tokenizer) |
| 757 | + input_ids, extra_kwargs = __prepare_inputs(batch_size, seq_length, tokenizer) |
| 758 | + extra_kwargs["attn_name"] = ATTN_NAME |
719 | 759 |
720 | 760 | # warmup aiu model |
721 | | - warmup_model(model, input_ids, max_new_tokens, **padding_kwargs) |
| 761 | + warmup_model(model, input_ids, max_new_tokens, compile_dynamic_sendnn, **extra_kwargs) |
| 762 | + |
| 763 | + # generate cpu validation info |
| 764 | + cpu_validation_info = __load_validation_info( |
| 765 | + model_path, batch_size, seq_length, max_new_tokens, tokenizer, 0 |
| 766 | + ) |
| 767 | + if cpu_validation_info is None: |
| 768 | + cpu_validation_info = extract_validation_information( |
| 769 | + validation_model, |
| 770 | + input_ids, |
| 771 | + max_new_tokens, |
| 772 | + LogitsExtractorHook(), |
| 773 | + attn_algorithm="math", |
| 774 | + **extra_kwargs, |
| 775 | + ) |
| 776 | + |
| 777 | + if save_validation_info_outputs: |
| 778 | + cpu_validation_info.save( |
| 779 | + __get_validation_info_full_path( |
| 780 | + model_path, batch_size, seq_length, max_new_tokens, 0 |
| 781 | + ) |
| 782 | + ) |
| 783 | + cpu_static_tokens = cpu_validation_info.get_info("tokens") |
| 784 | + eos_indexes = __find_eos_index( |
| 785 | + cpu_static_tokens, tokenizer.eos_token_id, seq_length, max_new_tokens |
| 786 | + ) |
| 787 | + dprint( |
| 788 | + "cpu validation info extracted for validation level 0 and validation level 1 (iter=0)" |
| 789 | + ) |
722 | 790 |
723 | | - # aiu validatation |
| 791 | + # first test validation level 0 |
724 | 792 | aiu_validation_info = extract_validation_information( |
725 | | - model, |
726 | | - input_ids, |
727 | | - max_new_tokens, |
728 | | - None, |
729 | | - only_last_token=True, |
730 | | - **padding_kwargs |
731 | | -) |
| 793 | + model, input_ids, max_new_tokens, None, only_last_token="paged" not in ATTN_NAME, **extra_kwargs |
| 794 | + ) |
| 795 | + dprint("aiu validation info extracted for validation level 0") |
| 796 | + |
| 797 | + # check cache status before validating cached results |
| 798 | + updated_cache_len = len(os.listdir(os.getcwd()+"/.cache")) if os.path.isdir(os.getcwd()+"/.cache") else 0 |
| 799 | + if cache_status == "miss": |
| 800 | + assert updated_cache_len == max_new_tokens, ( |
| 801 | + "cache directory not populated on cache miss" |
| 802 | + ) |
| 803 | + return |
| 804 | + else: |
| 805 | + assert updated_cache_len == max_new_tokens, ( |
| 806 | + "cache miss occurred when hit was expected" |
| 807 | + ) |
| 808 | + |
| 809 | + # validate level 0 |
| 810 | + failed_responses = validate_level_0( |
| 811 | + aiu_validation_info.get_info("tokens"), cpu_static_tokens |
| 812 | + ) |
| 813 | + |
| 814 | + failed_validation_level_0 = len(failed_responses) != 0 |
| 815 | + |
| 816 | + # if level 0 fails validation, validate level 1 |
| 817 | + if FORCE_VALIDATION_LEVEL_1 or failed_validation_level_0: |
| 818 | + |
| 819 | + if failed_validation_level_0: |
| 820 | + dprint("failed validation level 0, testing validation level 1") |
| 821 | + else: |
| 822 | + dprint("passed validation level 0, testing validation level 1") |
| 823 | + |
| 824 | + # metric calculator based on the cross-entropy and mean diff for each decode step |
| 825 | + def _metric_calculator(r: torch.Tensor, t: torch.Tensor): |
| 826 | + cross_entropy = torch.nn.CrossEntropyLoss()( |
| 827 | + r, t.softmax(dim=1).to(dtype=torch.float32) |
| 828 | + ) |
| 829 | + diff = torch.mean( |
| 830 | + torch.abs( |
| 831 | + r.softmax(dim=1).to(dtype=torch.float32) |
| 832 | + - t.softmax(dim=1).to(dtype=torch.float32) |
| 833 | + ) |
| 834 | + ) |
| 835 | + return (cross_entropy, diff) |
| 836 | + |
| 837 | + iters = 1024 // max_new_tokens |
| 838 | + ce_fail_responses_list = [] |
| 839 | + diff_fail_responses_list = [] |
| 840 | + total_tokens = 0 |
| 841 | + for i in range(iters): |
| 842 | + # for iteration 0, we have computed the cpu validation info in the prior step for seed=0, so skip |
| 843 | + if i != 0: |
| 844 | + input_ids, extra_kwargs = __prepare_inputs( |
| 845 | + batch_size, seq_length, tokenizer, seed=i |
| 846 | + ) |
| 847 | + extra_kwargs["attn_name"] = ATTN_NAME |
| 848 | + cpu_validation_info = __load_validation_info( |
| 849 | + model_path, batch_size, seq_length, max_new_tokens, tokenizer, i |
| 850 | + ) |
| 851 | + if cpu_validation_info is None: |
| 852 | + cpu_validation_info = extract_validation_information( |
| 853 | + validation_model, |
| 854 | + input_ids, |
| 855 | + max_new_tokens, |
| 856 | + LogitsExtractorHook(), |
| 857 | + attn_algorithm="math", |
| 858 | + **extra_kwargs, |
| 859 | + ) |
| 860 | + dprint( |
| 861 | + f"cpu validation info extracted for validation level 1 - iter={i}" |
| 862 | + ) |
| 863 | + if save_validation_info_outputs: |
| 864 | + cpu_validation_info.save( |
| 865 | + __get_validation_info_full_path( |
| 866 | + model_path, batch_size, seq_length, max_new_tokens, i |
| 867 | + ) |
| 868 | + ) |
| 869 | + cpu_static_tokens = cpu_validation_info.get_info("tokens") |
| 870 | + eos_indexes = __find_eos_index( |
| 871 | + cpu_static_tokens, |
| 872 | + tokenizer.eos_token_id, |
| 873 | + seq_length, |
| 874 | + max_new_tokens, |
| 875 | + ) |
| 876 | + |
| 877 | + # generate aiu validation info |
| 878 | + aiu_validation_info = extract_validation_information( |
| 879 | + model, |
| 880 | + input_ids, |
| 881 | + max_new_tokens, |
| 882 | + GoldenTokenHook(cpu_static_tokens), |
| 883 | + only_last_token=ATTN_TYPE != "paged", |
| 884 | + **extra_kwargs, |
| 885 | + ) |
| 886 | + dprint(f"aiu validation info extracted for validation level 1 - iter={i}") |
| 887 | + if save_validation_info_outputs: |
| 888 | + aiu_validation_info.save( |
| 889 | + __get_validation_info_full_path( |
| 890 | + model_path, batch_size, seq_length, max_new_tokens, i, "aiu" |
| 891 | + ) |
| 892 | + ) |
| 893 | + |
| 894 | + # capture all level 1 metrics |
| 895 | + level_1_metrics = capture_level_1_metrics( |
| 896 | + cpu_validation_info.get_info("logits"), |
| 897 | + aiu_validation_info.get_info("logits"), |
| 898 | + top_k_loss_calculator(20, _metric_calculator), |
| 899 | + ) |
| 900 | + # only consider those metrics captured prior to the eos |
| 901 | + level_1_metrics = __filter_before_eos(level_1_metrics, eos_indexes) |
| 902 | + |
| 903 | + # if we do not have real model weights, use a default_metrics_threshold |
| 904 | + if USE_MICRO_MODELS and micro_model_path is None: |
| 905 | + ce_threshold, diff_threshold = default_metrics_threshold |
| 906 | + # if we have real weights, try to get the proper validation metrics threshold
| 907 | + else: |
| 908 | + # if we have a micro model with real weights, but no real thresholds, default to the full model thresholds |
| 909 | + if USE_MICRO_MODELS: |
| 910 | + ce_threshold, diff_threshold = fail_thresholds.get( |
| 911 | + (model_path, True), fail_thresholds.get((model_path, False), default_metrics_threshold) |
| 912 | + ) |
| 913 | + else: |
| 914 | + ce_threshold, diff_threshold = fail_thresholds.get( |
| 915 | + (model_path, False), default_metrics_threshold |
| 916 | + ) |
| 917 | + |
| 918 | + # get all failed responses for each metric |
| 919 | + ce_fail_responses = filter_failed_level_1_cases( |
| 920 | + level_1_metrics, lambda m: m[0] >= ce_threshold |
| 921 | + ) |
| 922 | + diff_fail_responses = filter_failed_level_1_cases( |
| 923 | + level_1_metrics, |
| 924 | + lambda m: m[1] >= diff_threshold, |
| 925 | + ) |
| 926 | + |
| 927 | + ce_fail_responses_list.extend(ce_fail_responses) |
| 928 | + diff_fail_responses_list.extend(diff_fail_responses) |
| 929 | + total_tokens += len(level_1_metrics) |
| 930 | + |
| 931 | + # test the failure rates across all tokens
| 932 | + diff_failure_rate = len(diff_fail_responses_list) / total_tokens |
| 933 | + ce_failure_rate = len(ce_fail_responses_list) / total_tokens |
| 934 | + dprint(f"mean diff failure rate: {diff_failure_rate}") |
| 935 | + dprint(f"cross entropy loss failure rate: {ce_failure_rate}") |
| 936 | + if "mean_diff" not in skip_assertions: |
| 937 | + assert diff_failure_rate < failure_rate_threshold, ( |
| 938 | + f"failure rate for mean diff was too high: {diff_failure_rate}" |
| 939 | + ) |
| 940 | + if "ce" not in skip_assertions: |
| 941 | + assert ce_failure_rate < failure_rate_threshold, ( |
| 942 | + f"failure rate for cross entropy loss was too high: {ce_failure_rate}" |
| 943 | + ) |
| 944 | + print("passed validation level 1") |
| 945 | + else: |
| 946 | + print("passed validation level 0") |
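For reviewers, the new miss/hit check boils down to the standalone sketch below. It is an illustration only, not the test itself: the environment variables, the `.cache` location, and the one-entry-per-generated-token expectation are taken from the diff above, while `run_generation` and `check_sendnn_cache` are hypothetical stand-ins for the `warmup_model` / `extract_validation_information` calls.

```python
# Simplified sketch of the miss/hit cache check (illustration; check_sendnn_cache
# and run_generation are hypothetical stand-ins, not part of the test suite).
import os
import shutil

CACHE_DIR = os.path.join(os.getcwd(), ".cache")


def check_sendnn_cache(cache_status: str, max_new_tokens: int, run_generation) -> None:
    # same environment setup as test_cache in the diff
    os.environ["TORCH_SENDNN_CACHE_ENABLE"] = "1"
    os.environ["TORCH_SENDNN_CACHE_DIR"] = CACHE_DIR

    if cache_status == "miss" and os.path.isdir(CACHE_DIR):
        # a miss must start from an empty cache, so drop leftovers from earlier runs
        shutil.rmtree(CACHE_DIR)

    # stands in for warmup_model(...) + extract_validation_information(...),
    # which populate CACHE_DIR through the sendnn backend
    run_generation()

    n_entries = len(os.listdir(CACHE_DIR)) if os.path.isdir(CACHE_DIR) else 0
    # the test expects one cache entry per generated token in both cases
    if cache_status == "miss":
        assert n_entries == max_new_tokens, "cache directory not populated on cache miss"
    else:
        assert n_entries == max_new_tokens, "cache miss occurred when hit was expected"
```

Because pytest runs parametrized cases in declaration order, the "hit" case appears to rely on the preceding "miss" case having populated `.cache` earlier in the same session and working directory.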