160 changes: 159 additions & 1 deletion tests/models/test_decoders.py
@@ -7,6 +7,8 @@
import torch
from torch import distributed as dist
from torch.fx.experimental import _config as fx_config
from torch_sendnn.backends.sendnn_backend import _get_global_state
from torch_sendnn.utils.graph_cache import SpyreGraphCache

from aiu_fms_testing_utils.testing.validation import (
extract_validation_information,
@@ -29,6 +31,7 @@
from transformers import AutoTokenizer

from aiu_fms_testing_utils.utils.aiu_setup import dprint, aiu_dist_setup
import shutil
import os

try:
@@ -132,7 +135,7 @@
if USE_MICRO_MODELS:
VALIDATION_INFO_DIR = os.path.join(VALIDATION_INFO_DIR, "tiny_models")

# pass custom model path list for eg: EXPORT FMS_TEST_SHAPES_COMMON_MODEL_PATHS="/tmp/models/granite-3-8b-base,/tmp/models/granite-7b-base"
# pass custom model path list for eg: EXPORT FMS_TESTING_COMMON_MODEL_PATHS="/tmp/models/granite-3-8b-base,/tmp/models/granite-7b-base"
if isinstance(COMMON_MODEL_PATHS, str):
COMMON_MODEL_PATHS = COMMON_MODEL_PATHS.split(",")

@@ -593,6 +596,8 @@ def _get_device_validation_information(
token_iter,
ATTN_NAME,
)
if cpu_validation_info is not None:
return cpu_validation_info

if cpu_validation_info is not None:
return cpu_validation_info
@@ -830,6 +835,8 @@ def _run_cpu_aiu_validation_test(
aiu_model,
micro_model_path,
record_property,
verify_cache_state=None,
warmup_only=False,
):
# Get the tokenizer and AIU / CPU models to compare
tokenizer = AutoTokenizer.from_pretrained(model_path)
@@ -853,6 +860,16 @@
aiu_model, input_ids, max_new_tokens, COMPILE_DYNAMIC_SENDNN, **extra_kwargs
)

# Used only for cache tests; this is a zero-argument closure that
# asserts the torch_sendnn cache is in the expected state for
# this test
if verify_cache_state is not None:
verify_cache_state()

# For some tests, e.g., cache checks, we only need to run the warmup
if warmup_only:
return

# Run validation level 0
failed_validation_level_0, validation_zero_info = _run_validation_level_0(
model_path,
@@ -888,6 +905,86 @@ def _run_cpu_aiu_validation_test(
)


def _reset_cache_settings(purge_cache_dir, cache_dir=None):
os.environ["TORCH_SENDNN_CACHE_ENABLE"] = "1"
os.environ["COMPILATION_MODE"] = "offline_decoder"
if cache_dir is not None:
# Might be a posixpath
cache_dir = str(cache_dir)
os.environ["TORCH_SENDNN_CACHE_DIR"] = cache_dir

# Ensure we start in clean state
if purge_cache_dir and os.path.isdir(cache_dir):
shutil.rmtree(cache_dir)
os.mkdir(cache_dir)

# NOTE: currently, the cache dir is pulled from
# TORCH_SENDNN_CACHE_DIR at initialization time,
# so this should correctly use the cache_dir
_get_global_state().use_aiu_cache = True
_get_global_state().spyre_graph_cache = SpyreGraphCache()
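
A minimal sketch of how this helper is used in the two phases below, assuming a shared scratch directory (the path here is hypothetical; the tests themselves pass pytest's tmp_path):

CACHE_DIR = "/tmp/spyre_cache"  # hypothetical; the real tests use pytest's tmp_path

# warmup phase (use_cached_model fixture): start from an empty cache directory
_reset_cache_settings(purge_cache_dir=True, cache_dir=CACHE_DIR)

# verification phase (test_cache): keep whatever the warmup populated
_reset_cache_settings(purge_cache_dir=False, cache_dir=CACHE_DIR)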


@pytest.fixture
def use_cached_model(request, persistent_model, record_property, tmp_path):
"""Configures the torchsendnn cache and runs the AIU model (warmup)
prior to test execution; this is computationally expensive and should
only be used in situations like testing cache hit correctness.
"""
torch.manual_seed(42)
torch.set_grad_enabled(False)
_reset_cache_settings(purge_cache_dir=True, cache_dir=tmp_path)

model_path, batch_size, seq_length, max_new_tokens = request.param
micro_model_path = MICRO_MODEL_MAPPING.get(model_path, None)

def verify_cache_miss():
cache_dir = str(tmp_path)
updated_cache_len = (
len(os.listdir(cache_dir)) if os.path.isdir(cache_dir) else 0
)
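# the warmup is expected to populate one cache entry per generated token,
# hence the comparison against max_new_tokens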
assert updated_cache_len == max_new_tokens, (
"cache directory not populated on cache miss"
)

dprint(
f"Setting up cache [i.e., cache miss check] for model={model_path}, batch_size={batch_size}, seq_length={seq_length}, max_new_tokens={max_new_tokens}, micro_model={USE_MICRO_MODELS}"
)

# we don't currently support inferring gptq from get_model, so we must use an adapter with hf_configured
gptq_kwargs_aiu, gptq_kwargs_cpu = __maybe_get_gptq_kwargs(model_path)
is_gptq = len(gptq_kwargs_aiu) != 0
is_fp8 = "fp8" in ATTN_NAME
model_kwargs = _get_common_model_kwargs(is_gptq, model_path)

# Get the AIU model w/ the persistent model fixture
model = persistent_model.get_or_create(
is_gptq, is_fp8, **gptq_kwargs_aiu, **model_kwargs
)

validation_model = _get_cpu_model(
is_gptq,
is_fp8,
micro_model_state_dict=model.state_dict() if USE_MICRO_MODELS else None,
**gptq_kwargs_cpu,
**model_kwargs,
)

_run_cpu_aiu_validation_test(
model_path,
batch_size,
seq_length,
max_new_tokens,
validation_model,
model,
micro_model_path,
record_property,
verify_cache_state=verify_cache_miss,
warmup_only=True,
)
return request.param


@pytest.mark.parametrize(
"model_path,batch_size,seq_length,max_new_tokens", COMMON_SHAPES
)
@@ -937,3 +1034,64 @@ def test_common_shapes(
micro_model_path,
record_property,
)


@pytest.mark.parametrize(
"use_cached_model",
COMMON_SHAPES,
indirect=True,
)
def test_cache(use_cached_model, persistent_model, record_property, tmp_path):
torch.manual_seed(42)
torch.set_grad_enabled(False)
_reset_cache_settings(purge_cache_dir=False, cache_dir=tmp_path)

# use_cached_model is an indirectly parametrized fixture (see the sketch after
# this test); its return value is the expanded tuple from COMMON_SHAPES, so we unpack it here
model_path, batch_size, seq_length, max_new_tokens = use_cached_model
micro_model_path = MICRO_MODEL_MAPPING.get(model_path, None)

def verify_cache_hit():
cache_dir = str(tmp_path)
updated_cache_len = (
len(os.listdir(cache_dir)) if os.path.isdir(cache_dir) else 0
)
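# the warmup fixture already populated max_new_tokens entries; an unchanged
# count indicates the cached graphs were reused rather than recompiled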
assert updated_cache_len == max_new_tokens, (
"cache miss occurred when hit was expected"
)

dprint(
f"testing: model={model_path}, batch_size={batch_size}, seq_length={seq_length}, max_new_tokens={max_new_tokens}, micro_model={USE_MICRO_MODELS}, for cache hit"
)

# we don't currently support inferring gptq from get_model, so we must use an adapter with hf_configured
gptq_kwargs_aiu, gptq_kwargs_cpu = __maybe_get_gptq_kwargs(model_path)
is_gptq = len(gptq_kwargs_aiu) != 0
is_fp8 = "fp8" in ATTN_NAME
model_kwargs = _get_common_model_kwargs(is_gptq, model_path)

# Get the AIU model w/ the persistent model fixture
model = persistent_model.get_or_create(
is_gptq, is_fp8, **gptq_kwargs_aiu, **model_kwargs
)

Contributor:
It looks like we are re-creating the model and validation_model when it's already being created in the fixture. Is this required?

Contributor Author:
Nope! Good point, returned them both out of the fixture and deleted it from the cache hit check so that it'll be reused.

validation_model = _get_cpu_model(
is_gptq,
is_fp8,
micro_model_state_dict=model.state_dict() if USE_MICRO_MODELS else None,
**gptq_kwargs_cpu,
**model_kwargs,
)

_run_cpu_aiu_validation_test(
model_path,
batch_size,
seq_length,
max_new_tokens,
validation_model,
model,
micro_model_path,
record_property,
verify_cache_state=verify_cache_hit,
warmup_only=True,
)
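
test_cache above relies on pytest's indirect parametrization: each COMMON_SHAPES tuple is routed through the use_cached_model fixture (which performs the expensive warmup) before the test body runs, and the fixture hands the tuple back via request.param. A minimal, self-contained sketch of that pattern, with hypothetical names:

import pytest

@pytest.fixture
def prepared_case(request):
    # with indirect=True, each parametrize tuple arrives as request.param
    model_path, batch_size = request.param
    # ...expensive one-time setup (e.g., warmup) would happen here...
    return request.param

@pytest.mark.parametrize(
    "prepared_case", [("model-a", 1), ("model-b", 2)], indirect=True
)
def test_uses_prepared_case(prepared_case):
    # the test receives the fixture's return value, which it can unpack
    model_path, batch_size = prepared_case
    assert batch_size in (1, 2)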