29 changes: 21 additions & 8 deletions tensorrt_llm/_torch/auto_deploy/models/hf.py
@@ -49,30 +49,43 @@

@contextmanager
def hf_load_state_dict_with_device(device: DeviceLikeType):
"""Patch HF load_state_dict to use provided device.
NOTE (lucaslie): this function is called by ``load_checkpoint_in_model``. We provide the device
map here as a patch instead of going through ``load_checkpoint_in_model``. This is because
otherwise ``load_checkpoint_in_model`` will execute its own state_dict loading logic instead of
calling ``nn.Module.load_state_dict``. However, we rely on the state dict loading hooks in
``nn.Module.load_state_dict`` to correctly load the weights. By providing the device map here,
we can ensure that ``load_checkpoint_in_model`` will call ``nn.Module.load_state_dict``.
"""Patch HF loading utilities according to our needs.
The following patches are applied:
1. load_state_dict to use provided device. NOTE (lucaslie): this function is called by
``load_checkpoint_in_model``. We provide the device map here as a patch instead of going
through ``load_checkpoint_in_model``. This is because otherwise
``load_checkpoint_in_model`` will execute its own state_dict loading logic instead of
calling ``nn.Module.load_state_dict``. However, we rely on the state dict loading hooks
in ``nn.Module.load_state_dict`` to correctly load the weights. By providing the device
map here, we can ensure that ``load_checkpoint_in_model`` will call
``nn.Module.load_state_dict``.
2. change logging level of logger to ERROR to avoid logging warnings from HF state_dict
loading for missing/unexpected keys (happens for MoE expert-sharded layers for example).
"""
# save the original load_state_dict method
original_load_state_dict = modeling.load_state_dict

# save the original logger level
original_logger_level = modeling.logger.level

# Define and apply the patched version
def load_state_dict_with_device(checkpoint_file, device_map=None):
return original_load_state_dict(checkpoint_file, device_map={"": device})

# Apply the patch
modeling.load_state_dict = load_state_dict_with_device

# Change the logger level to ERROR
modeling.logger.setLevel("ERROR")

try:
yield
finally:
# Restore the original method, even if an exception occurred
modeling.load_state_dict = original_load_state_dict
# Restore the original logger level
modeling.logger.setLevel(original_logger_level)


# TODO (lucaslie): continue working on the base class
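The patched context manager above is intended to wrap accelerate's checkpoint loading. A minimal usage sketch, assuming the import path follows the file location shown in this diff and using a placeholder module and checkpoint directory:

```python
import torch
import torch.nn as nn
from accelerate.utils import load_checkpoint_in_model

from tensorrt_llm._torch.auto_deploy.models.hf import hf_load_state_dict_with_device

# Placeholder module and checkpoint location for illustration only.
model = nn.Linear(16, 16)
ckpt_dir = "/path/to/hf/checkpoint"

# While the context is active, accelerate's load_state_dict is forced onto the
# given device, which lets load_checkpoint_in_model defer to
# nn.Module.load_state_dict and its loading hooks; HF warnings about
# missing/unexpected keys are silenced for the duration.
with hf_load_state_dict_with_device(torch.device("cuda:0")):
    load_checkpoint_in_model(model, ckpt_dir)
```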
12 changes: 12 additions & 0 deletions tests/integration/defs/accuracy/test_llm_api_autodeploy.py
@@ -173,6 +173,18 @@ def get_default_kwargs(self):
"compile_backend": "torch-cudagraph",
"free_mem_ratio": 0.7,
"cuda_graph_batch_sizes": [1, 2, 4, 8, 16, 32, 64, 128],
"transforms": {
"detect_sharding": {
"sharding_source": ['factory', 'heuristic'],
"sharding_dims": ['ep', 'bmm'],
},
# NOTE: some accuracy benchmarks may require fp32 precision for mamba cache
# "insert_cached_ssm_attention": {
# "cache_config": {
# "mamba_dtype": "float32",
# },
# },
}
}

def get_default_sampling_params(self):
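For reference, a hedged sketch of the same override dictionary as standalone Python, with the optional fp32 mamba-cache setting from the comment above enabled. Key names are copied from the diff; whether a given accuracy benchmark actually needs the fp32 cache is an assumption, not a requirement of the harness:

```python
# Mirrors the default kwargs added in the test above, with the commented-out
# fp32 mamba cache option turned on for benchmarks that may require it.
autodeploy_kwargs = {
    "compile_backend": "torch-cudagraph",
    "free_mem_ratio": 0.7,
    "cuda_graph_batch_sizes": [1, 2, 4, 8, 16, 32, 64, 128],
    "transforms": {
        "detect_sharding": {
            "sharding_source": ["factory", "heuristic"],
            "sharding_dims": ["ep", "bmm"],
        },
        "insert_cached_ssm_attention": {
            "cache_config": {"mamba_dtype": "float32"},
        },
    },
}
```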