Commit 06cd251

avery-blanchard authored and alex-jw-brooks committed
Update cache test, add validation for cached run
1 parent b774201 commit 06cd251
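
For orientation, a minimal sketch (not part of the commit) of the cache check this change adds to test_cache: a "miss" run starts by clearing the torch_sendnn cache directory and, after warmup and the level-0 generation pass, asserts that the directory now holds max_new_tokens entries; a "hit" run asserts the directory was already populated by an earlier miss run. Helper names below are hypothetical; the cache location and assertion messages mirror the diff.

import os
import shutil

# Same location as TORCH_SENDNN_CACHE_DIR in the test (hypothetical constant name)
CACHE_DIR = os.path.join(os.getcwd(), ".cache")

def reset_cache_for_miss_run() -> None:
    # A "miss" run must start from an empty cache directory
    if os.path.isdir(CACHE_DIR):
        shutil.rmtree(CACHE_DIR)

def assert_cache_state(cache_status: str, max_new_tokens: int) -> None:
    # Mirrors the assertions added in test_cache: after the first generation
    # pass, the cache directory is expected to hold one entry per generated token.
    num_entries = len(os.listdir(CACHE_DIR)) if os.path.isdir(CACHE_DIR) else 0
    if cache_status == "miss":
        assert num_entries == max_new_tokens, "cache directory not populated on cache miss"
    else:
        assert num_entries == max_new_tokens, "cache miss occurred when hit was expected"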

File tree

1 file changed (+240 -25 lines)

tests/models/test_decoders.py

Lines changed: 240 additions & 25 deletions
@@ -24,7 +24,7 @@
 )
 import json
 from aiu_fms_testing_utils.utils.aiu_setup import dprint, aiu_dist_setup
-
+import shutil
 import os

 try:
@@ -175,7 +175,6 @@
 )
 os.environ["VLLM_DT_MAX_BATCH_SIZE"] = str(max(max(common_batch_sizes), 2))

-cache_params = list(itertools.product([common_model_paths[0]], [common_batch_sizes[0]], [common_seq_lengths[0]], [common_max_new_tokens[0]], ["miss", "hit"]))

 # thresholds are chosen based on 1024 tokens per sequence
 # 1% error threshold rate between cpu fp32 and cuda fp16
@@ -676,56 +675,272 @@ def _metric_calculator(r: torch.Tensor, t: torch.Tensor):
     else:
         print("passed validation level 0")

-@pytest.mark.parametrize("model_path,batch_size,seq_length,max_new_tokens,cache_status", cache_params)
-def test_cache(model_path, batch_size, seq_length, max_new_tokens, cache_status):
+@pytest.mark.parametrize("cache_status", ["miss", "hit"])
+def test_cache(cache_status):
     torch.manual_seed(42)
+    torch.set_grad_enabled(False)
     os.environ["TORCH_SENDNN_CACHE_ENABLE"] = "1"
+    os.environ["TORCH_SENDNN_CACHE_DIR"] = os.getcwd()+"/.cache"
     os.environ["COMPILATION_MODE"] = "offline_decoder"

+    if cache_status == "miss" and os.path.isdir(os.getcwd()+"/.cache"):
+        # Remove cache from previous runs
+        shutil.rmtree(os.getcwd()+"/.cache")
+
+    model_path = "ibm-granite/granite-3.3-8b-instruct"
+    batch_size = common_batch_sizes[0]
+    seq_length = common_seq_lengths[0]
+    max_new_tokens = common_max_new_tokens[0]
+
     dprint(f"testing with cache: model={model_path}, batch_size={batch_size}, seq_length={seq_length}, max_new_tokens={max_new_tokens}, micro_model={USE_MICRO_MODELS}, cache={cache_status}")

-    if USE_MICRO_MODELS:
+    # we don't currently support inferring gptq from get_model, so we must use an adapter with hf_configured
+    gptq_kwargs_aiu, gptq_kwargs_cpu = __maybe_get_gptq_kwargs(model_path)
+    is_gptq = len(gptq_kwargs_aiu) != 0
+
+    micro_model_path = micro_model_mapping.get(model_path, None)
+    if USE_MICRO_MODELS and micro_model_path is None:
+        dprint("using randomly initialized model")
         micro_model_kwargs = {"architecture": "hf_configured", "nlayers": 3}
     else:
-        micro_model_kwargs = {"architecture": "hf_pretrained"}
-
+        dprint("using trained model")
+        micro_model_kwargs = {"architecture": "hf_pretrained"}
+
     if not USE_MICRO_MODELS and os.path.exists(model_path):
         model_path_kwargs = {"model_path": model_path}
+    elif USE_MICRO_MODELS and micro_model_path is not None:
+        model_path_kwargs = {"model_path": micro_model_path}
     else:
         model_path_kwargs = {"variant": model_path}
-
+
     distributed_kwargs = {}
     if USE_DISTRIBUTED:
-        distributed_kwargs["distr_param"] = "tp"
+        distributed_kwargs["distributed_strategy"] = "tp"
         distributed_kwargs["group"] = dist.group.WORLD
-    get_model_kwargs = {**model_path_kwargs, **micro_model_kwargs, **distributed_kwargs}
+
+    get_model_kwargs = {}
+    if not is_gptq:
+        get_model_kwargs = {
+            **model_path_kwargs,
+            **micro_model_kwargs,
+            **distributed_kwargs,
+        }

     tokenizer = tokenizers.get_tokenizer(model_path)

     # prepare the AIU model
     model = get_model(
+        device_type="cpu",
+        data_type=None if is_gptq else torch.float16,
+        fused_weights=False,
+        **get_model_kwargs,
+    )
+
+    model.eval()
+    model.compile(backend="sendnn")
+
+    # prepare the cpu model
+    validation_model = get_model(
         device_type="cpu",
+        data_type=None if is_gptq else torch.float32,
         fused_weights=False,
-        **get_model_kwargs
+        **gptq_kwargs_cpu,
+        **get_model_kwargs,
     )

-    model.eval()
-    torch.set_grad_enabled(False)
-    model.compile(backend="sendnn_decoder")
-
+    if USE_MICRO_MODELS:
+        serialization.load_state_dict_into_model(
+            validation_model, model.state_dict(), **__custom_adapter
+        )

     # prepare input_ids
-    input_ids, padding_kwargs = __prepare_inputs(batch_size, seq_length, tokenizer)
+    input_ids, extra_kwargs = __prepare_inputs(batch_size, seq_length, tokenizer)
+    extra_kwargs["attn_name"] = ATTN_NAME

     # warmup aiu model
-    warmup_model(model, input_ids, max_new_tokens, **padding_kwargs)
+    warmup_model(model, input_ids, max_new_tokens, compile_dynamic_sendnn, **extra_kwargs)
+
+    # generate cpu validation info
+    cpu_validation_info = __load_validation_info(
+        model_path, batch_size, seq_length, max_new_tokens, tokenizer, 0
+    )
+    if cpu_validation_info is None:
+        cpu_validation_info = extract_validation_information(
+            validation_model,
+            input_ids,
+            max_new_tokens,
+            LogitsExtractorHook(),
+            attn_algorithm="math",
+            **extra_kwargs,
+        )
+
+        if save_validation_info_outputs:
+            cpu_validation_info.save(
+                __get_validation_info_full_path(
+                    model_path, batch_size, seq_length, max_new_tokens, 0
+                )
+            )
+    cpu_static_tokens = cpu_validation_info.get_info("tokens")
+    eos_indexes = __find_eos_index(
+        cpu_static_tokens, tokenizer.eos_token_id, seq_length, max_new_tokens
+    )
+    dprint(
+        "cpu validation info extracted for validation level 0 and validation level 1 (iter=0)"
+    )

-    # aiu validatation
+    # first test validation level 0
     aiu_validation_info = extract_validation_information(
-        model,
-        input_ids,
-        max_new_tokens,
-        None,
-        only_last_token=True,
-        **padding_kwargs
-    )
+        model, input_ids, max_new_tokens, None, only_last_token="paged" not in ATTN_NAME, **extra_kwargs
+    )
+    dprint("aiu validation info extracted for validation level 0")
+
+    # check cache status before validating cached results
+    updated_cache_len = len(os.listdir(os.getcwd()+"/.cache")) if os.path.isdir(os.getcwd()+"/.cache") else 0
+    if cache_status == "miss":
+        assert updated_cache_len == max_new_tokens, (
+            "cache directory not populated on cache miss"
+        )
+        return
+    else:
+        assert updated_cache_len == max_new_tokens, (
+            "cache miss occurred when hit was expected"
+        )
+
+    # validate level 0
+    failed_responses = validate_level_0(
+        aiu_validation_info.get_info("tokens"), cpu_static_tokens
+    )
+
+    failed_validation_level_0 = len(failed_responses) != 0
+
+    # if level 0 fails validation, validate level 1
+    if FORCE_VALIDATION_LEVEL_1 or failed_validation_level_0:
+
+        if failed_validation_level_0:
+            dprint("failed validation level 0, testing validation level 1")
+        else:
+            dprint("passed validation level 0, testing validation level 1")
+
+        # metric calculator based on the cross-entropy and mean diff for each decode step
+        def _metric_calculator(r: torch.Tensor, t: torch.Tensor):
+            cross_entropy = torch.nn.CrossEntropyLoss()(
+                r, t.softmax(dim=1).to(dtype=torch.float32)
+            )
+            diff = torch.mean(
+                torch.abs(
+                    r.softmax(dim=1).to(dtype=torch.float32)
+                    - t.softmax(dim=1).to(dtype=torch.float32)
+                )
+            )
+            return (cross_entropy, diff)
+
+        iters = 1024 // max_new_tokens
+        ce_fail_responses_list = []
+        diff_fail_responses_list = []
+        total_tokens = 0
+        for i in range(iters):
+            # for iteration 0, we have computed the cpu validation info in the prior step for seed=0, so skip
+            if i != 0:
+                input_ids, extra_kwargs = __prepare_inputs(
+                    batch_size, seq_length, tokenizer, seed=i
+                )
+                extra_kwargs["attn_name"] = ATTN_NAME
+                cpu_validation_info = __load_validation_info(
+                    model_path, batch_size, seq_length, max_new_tokens, tokenizer, i
+                )
+                if cpu_validation_info is None:
+                    cpu_validation_info = extract_validation_information(
+                        validation_model,
+                        input_ids,
+                        max_new_tokens,
+                        LogitsExtractorHook(),
+                        attn_algorithm="math",
+                        **extra_kwargs,
+                    )
+                    dprint(
+                        f"cpu validation info extracted for validation level 1 - iter={i}"
+                    )
+                    if save_validation_info_outputs:
+                        cpu_validation_info.save(
+                            __get_validation_info_full_path(
+                                model_path, batch_size, seq_length, max_new_tokens, i
+                            )
+                        )
+                cpu_static_tokens = cpu_validation_info.get_info("tokens")
+                eos_indexes = __find_eos_index(
+                    cpu_static_tokens,
+                    tokenizer.eos_token_id,
+                    seq_length,
+                    max_new_tokens,
+                )
+
+            # generate aiu validation info
+            aiu_validation_info = extract_validation_information(
+                model,
+                input_ids,
+                max_new_tokens,
+                GoldenTokenHook(cpu_static_tokens),
+                only_last_token=ATTN_TYPE != "paged",
+                **extra_kwargs,
+            )
+            dprint(f"aiu validation info extracted for validation level 1 - iter={i}")
+            if save_validation_info_outputs:
+                aiu_validation_info.save(
+                    __get_validation_info_full_path(
+                        model_path, batch_size, seq_length, max_new_tokens, i, "aiu"
+                    )
+                )
+
+            # capture all level 1 metrics
+            level_1_metrics = capture_level_1_metrics(
+                cpu_validation_info.get_info("logits"),
+                aiu_validation_info.get_info("logits"),
+                top_k_loss_calculator(20, _metric_calculator),
+            )
+            # only consider those metrics captured prior to the eos
+            level_1_metrics = __filter_before_eos(level_1_metrics, eos_indexes)
+
+            # if we do not have real model weights, use a default_metrics_threshold
+            if USE_MICRO_MODELS and micro_model_path is None:
+                ce_threshold, diff_threshold = default_metrics_threshold
+            # if we have real weights, try and get the proper validation metrics threshold
+            else:
+                # if we have a micro model with real weights, but no real thresholds, default to the full model thresholds
+                if USE_MICRO_MODELS:
+                    ce_threshold, diff_threshold = fail_thresholds.get(
+                        (model_path, True), fail_thresholds.get((model_path, False), default_metrics_threshold)
+                    )
+                else:
+                    ce_threshold, diff_threshold = fail_thresholds.get(
+                        (model_path, False), default_metrics_threshold
+                    )
+
+            # get all failed responses for each metric
+            ce_fail_responses = filter_failed_level_1_cases(
+                level_1_metrics, lambda m: m[0] >= ce_threshold
+            )
+            diff_fail_responses = filter_failed_level_1_cases(
+                level_1_metrics,
+                lambda m: m[1] >= diff_threshold,
+            )
+
+            ce_fail_responses_list.extend(ce_fail_responses)
+            diff_fail_responses_list.extend(diff_fail_responses)
+            total_tokens += len(level_1_metrics)
+
+        # test the failure rates for across all tokens
+        diff_failure_rate = len(diff_fail_responses_list) / total_tokens
+        ce_failure_rate = len(ce_fail_responses_list) / total_tokens
+        dprint(f"mean diff failure rate: {diff_failure_rate}")
+        dprint(f"cross entropy loss failure rate: {ce_failure_rate}")
+        if "mean_diff" not in skip_assertions:
+            assert diff_failure_rate < failure_rate_threshold, (
+                f"failure rate for mean diff was too high: {diff_failure_rate}"
+            )
+        if "ce" not in skip_assertions:
+            assert ce_failure_rate < failure_rate_threshold, (
+                f"failure rate for cross entropy loss was too high: {ce_failure_rate}"
+            )
+        print("passed validation level 1")
+    else:
+        print("passed validation level 0")
