
Commit b867afe

lucaslie authored and suyoggupta committed
[None][fix] AutoDeploy: update nano3 accuracy test (NVIDIA#9061)
Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com>
1 parent 9f0cc6c commit b867afe

File tree

2 files changed, +33 −8 lines changed


tensorrt_llm/_torch/auto_deploy/models/hf.py

Lines changed: 21 additions & 8 deletions
@@ -49,30 +49,43 @@
 
 @contextmanager
 def hf_load_state_dict_with_device(device: DeviceLikeType):
-    """Patch HF load_state_dict to use provided device.
-
-    NOTE (lucaslie): this function is called by ``load_checkpoint_in_model``. We provide the device
-    map here as a patch instead of going through ``load_checkpoint_in_model``. This is because
-    otherwise ``load_checkpoint_in_model`` will execute its own state_dict loading logic instead of
-    calling ``nn.Module.load_state_dict``. However, we rely on the state dict loading hooks in
-    ``nn.Module.load_state_dict`` to correctly load the weights. By providing the device map here,
-    we can ensure that ``load_checkpoint_in_model`` will call ``nn.Module.load_state_dict``.
+    """Patch HF loading utilities according to our needs.
+
+    Following patches are applied:
+    1. load_state_dict to use provided device. NOTE (lucaslie): this function is called by
+       ``load_checkpoint_in_model``. We provide the device map here as a patch instead of going
+       through ``load_checkpoint_in_model``. This is because otherwise
+       ``load_checkpoint_in_model`` will execute its own state_dict loading logic instead of
+       calling ``nn.Module.load_state_dict``. However, we rely on the state dict loading hooks
+       in ``nn.Module.load_state_dict`` to correctly load the weights. By providing the device
+       map here, we can ensure that ``load_checkpoint_in_model`` will call
+       ``nn.Module.load_state_dict``.
+    2. change logging level of logger to ERROR to avoid logging warnings from HF state_dict
+       loading for missing/unexpected keys (happens for MoE expert-sharded layers for example).
     """
     # save the original load_state_dict method
     original_load_state_dict = modeling.load_state_dict
 
+    # save the original logger level
+    original_logger_level = modeling.logger.level
+
     # Define and apply the patched version
     def load_state_dict_with_device(checkpoint_file, device_map=None):
         return original_load_state_dict(checkpoint_file, device_map={"": device})
 
     # Apply the patch
     modeling.load_state_dict = load_state_dict_with_device
 
+    # Change the logger level to ERROR
+    modeling.logger.setLevel("ERROR")
+
     try:
         yield
     finally:
         # Restore the original method, even if an exception occurred
         modeling.load_state_dict = original_load_state_dict
+        # Restore the original logger level
+        modeling.logger.setLevel(original_logger_level)
 
 
 # TODO (lucaslie): continue working on the base class
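For context, a minimal usage sketch of the patched context manager (not part of the commit; the toy module, checkpoint path, and `accelerate` import are assumptions for illustration):

```python
import torch
import torch.nn as nn
from accelerate import load_checkpoint_in_model

from tensorrt_llm._torch.auto_deploy.models.hf import hf_load_state_dict_with_device

# Toy stand-in for a real HF model (illustrative only).
model = nn.Linear(16, 16)

# While the patch is active, load_checkpoint_in_model delegates weight loading
# to nn.Module.load_state_dict on the requested device, and HF warnings about
# missing/unexpected keys (e.g. MoE expert-sharded layers) are suppressed.
with hf_load_state_dict_with_device(torch.device("cuda:0")):
    load_checkpoint_in_model(model, checkpoint="/path/to/checkpoint")  # hypothetical path
```

Because the patch is applied inside a try/finally, both the original `load_state_dict` and the logger level are restored even if loading raises.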

tests/integration/defs/accuracy/test_llm_api_autodeploy.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,18 @@ def get_default_kwargs(self):
173173
"compile_backend": "torch-cudagraph",
174174
"free_mem_ratio": 0.7,
175175
"cuda_graph_batch_sizes": [1, 2, 4, 8, 16, 32, 64, 128],
176+
"transforms": {
177+
"detect_sharding": {
178+
"sharding_source": ['factory', 'heuristic'],
179+
"sharding_dims": ['ep', 'bmm'],
180+
},
181+
# NOTE: some accuracy benchmarks may require fp32 precision for mamba cache
182+
# "insert_cached_ssm_attention": {
183+
# "cache_config": {
184+
# "mamba_dtype": "float32",
185+
# },
186+
# },
187+
}
176188
}
177189

178190
def get_default_sampling_params(self):
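A hedged sketch of how these defaults might be consumed (the entry point and model id below are assumptions for illustration; only the kwargs themselves come from the diff above):

```python
# Hypothetical wiring: the accuracy harness builds an AutoDeploy LLM from
# get_default_kwargs(); the import path and constructor shown here are assumed.
from tensorrt_llm._torch.auto_deploy.llm import LLM  # assumed entry point

kwargs = {
    "compile_backend": "torch-cudagraph",
    "free_mem_ratio": 0.7,
    "cuda_graph_batch_sizes": [1, 2, 4, 8, 16, 32, 64, 128],
    "transforms": {
        # Shard via both the model factory and the heuristic detector, along
        # the expert-parallel (ep) and batched-matmul (bmm) dimensions.
        "detect_sharding": {
            "sharding_source": ["factory", "heuristic"],
            "sharding_dims": ["ep", "bmm"],
        },
    },
}
llm = LLM(model="nvidia/<nano3-checkpoint>", **kwargs)  # hypothetical model id
```

If an accuracy run needs fp32 precision for the mamba cache, the commented-out `insert_cached_ssm_attention` block in the diff can be enabled to set `mamba_dtype` to `float32`.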
