From a24754a1deb863bb1bf0a34e727b9b0d96137bb2 Mon Sep 17 00:00:00 2001
From: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
Date: Mon, 10 Nov 2025 20:57:52 -0800
Subject: [PATCH] update nano3 accuracy test

Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
---
 tensorrt_llm/_torch/auto_deploy/models/hf.py | 29 ++++++++++++++-----
 .../defs/accuracy/test_llm_api_autodeploy.py | 12 ++++++++
 2 files changed, 33 insertions(+), 8 deletions(-)

diff --git a/tensorrt_llm/_torch/auto_deploy/models/hf.py b/tensorrt_llm/_torch/auto_deploy/models/hf.py
index 41b5c90214c..2fd228c4e11 100644
--- a/tensorrt_llm/_torch/auto_deploy/models/hf.py
+++ b/tensorrt_llm/_torch/auto_deploy/models/hf.py
@@ -49,18 +49,26 @@
 
 @contextmanager
 def hf_load_state_dict_with_device(device: DeviceLikeType):
-    """Patch HF load_state_dict to use provided device.
-
-    NOTE (lucaslie): this function is called by ``load_checkpoint_in_model``. We provide the device
-    map here as a patch instead of going through ``load_checkpoint_in_model``. This is because
-    otherwise ``load_checkpoint_in_model`` will execute its own state_dict loading logic instead of
-    calling ``nn.Module.load_state_dict``. However, we rely on the state dict loading hooks in
-    ``nn.Module.load_state_dict`` to correctly load the weights. By providing the device map here,
-    we can ensure that ``load_checkpoint_in_model`` will call ``nn.Module.load_state_dict``.
+    """Patch HF loading utilities according to our needs.
+
+    The following patches are applied:
+    1. ``load_state_dict`` is patched to use the provided device. NOTE (lucaslie): this function
+       is called by ``load_checkpoint_in_model``. We provide the device map here as a patch
+       instead of going through ``load_checkpoint_in_model``. This is because otherwise
+       ``load_checkpoint_in_model`` will execute its own state_dict loading logic instead of
+       calling ``nn.Module.load_state_dict``. However, we rely on the state dict loading hooks
+       in ``nn.Module.load_state_dict`` to correctly load the weights. By providing the device
+       map here, we can ensure that ``load_checkpoint_in_model`` will call
+       ``nn.Module.load_state_dict``.
+    2. The logger level is raised to ERROR to avoid warnings from HF state_dict loading about
+       missing/unexpected keys (this happens for MoE expert-sharded layers, for example).
     """
     # save the original load_state_dict method
     original_load_state_dict = modeling.load_state_dict
 
+    # save the original logger level
+    original_logger_level = modeling.logger.level
+
     # Define and apply the patched version
     def load_state_dict_with_device(checkpoint_file, device_map=None):
         return original_load_state_dict(checkpoint_file, device_map={"": device})
@@ -68,11 +76,16 @@ def load_state_dict_with_device(checkpoint_file, device_map=None):
     # Apply the patch
     modeling.load_state_dict = load_state_dict_with_device
 
+    # Change the logger level to ERROR
+    modeling.logger.setLevel("ERROR")
+
     try:
         yield
     finally:
         # Restore the original method, even if an exception occurred
         modeling.load_state_dict = original_load_state_dict
+        # Restore the original logger level
+        modeling.logger.setLevel(original_logger_level)
 
 
 # TODO (lucaslie): continue working on the base class
diff --git a/tests/integration/defs/accuracy/test_llm_api_autodeploy.py b/tests/integration/defs/accuracy/test_llm_api_autodeploy.py
index 0d8649ac920..8b5e0e7d4ca 100644
--- a/tests/integration/defs/accuracy/test_llm_api_autodeploy.py
+++ b/tests/integration/defs/accuracy/test_llm_api_autodeploy.py
@@ -173,6 +173,18 @@ def get_default_kwargs(self):
             "compile_backend": "torch-cudagraph",
             "free_mem_ratio": 0.7,
             "cuda_graph_batch_sizes": [1, 2, 4, 8, 16, 32, 64, 128],
+            "transforms": {
+                "detect_sharding": {
+                    "sharding_source": ["factory", "heuristic"],
+                    "sharding_dims": ["ep", "bmm"],
+                },
+                # NOTE: some accuracy benchmarks may require fp32 precision for the mamba cache
+                # "insert_cached_ssm_attention": {
+                #     "cache_config": {
+                #         "mamba_dtype": "float32",
+                #     },
+                # },
+            },
         }
 
     def get_default_sampling_params(self):