
Commit 592a499

[Bug Fix] Add metadata as input for sharded_state_dict to follow latest Megatron code change (NVIDIA#606)
## What does this PR do?

**Type of change:** Bug fix

**Overview:** Add `metadata` as an input to `sharded_state_dict` to follow the latest Megatron Core [code change](NVIDIA/Megatron-LM@a2a1c89).

## Usage

```python
# Add a code snippet demonstrating how to use this
```

## Testing

Under the Megatron-Bridge repo, the command below runs smoothly:

```
pytest tests/functional_tests/quantization/test_quantization_workflow.py -v -s
```

Before the fix, the error looked like [this](https://github.com/NVIDIA-NeMo/Megatron-Bridge/actions/runs/19511909221/job/55906907649#step:3:7038).

## Before your PR is "*Ready for review*"

- **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: Yes/No
- **Did you write any new necessary tests?**: Yes/No
- **Did you add or update any necessary documentation?**: Yes/No
- **Did you update [Changelog](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes/No

## Additional Information

---------

Signed-off-by: James Shen <yueshen@nvidia.com>
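A minimal sketch of the changed call pattern: the Megatron plugins now forward `metadata` to the parent `sharded_state_dict` and fill in a `dp_cp_group` entry when the caller passes none. It assumes torch.distributed and Megatron-Core parallel state are already initialized (e.g. inside a Megatron-Bridge job); `model` is a placeholder for a converted Megatron module, not a name from this commit.

```python
# Minimal sketch, assuming Megatron-Core parallel state is already initialized
# inside a distributed job; `model` is a placeholder for a quantized or
# sparsified Megatron module handled by the plugins in this commit.
from megatron.core.parallel_state import get_data_parallel_group

# Callers may supply the data-parallel/context-parallel group explicitly ...
metadata = {"dp_cp_group": get_data_parallel_group(with_context_parallel=True)}
sharded_sd = model.sharded_state_dict(prefix="", sharded_offsets=(), metadata=metadata)

# ... or pass no metadata at all: the plugin calls ensure_metadata_has_dp_cp_group()
# internally before delegating to the parent class, so this also works.
sharded_sd = model.sharded_state_dict(prefix="")
```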
1 parent f06c3f9 commit 592a499

File tree

- modelopt/torch/quantization/plugins/megatron.py
- modelopt/torch/sparsity/weight_sparsity/plugins/megatron.py

2 files changed: +58 -3 lines changed


modelopt/torch/quantization/plugins/megatron.py

Lines changed: 29 additions & 2 deletions
```diff
@@ -230,6 +230,30 @@ def _register_extra_state_callbacks(model: torch.nn.Module):
 CUSTOM_MODEL_PLUGINS.add(megatron_replace_quant_module_hook)
 
 
+def ensure_metadata_has_dp_cp_group(metadata):
+    """Ensure `metadata` is a dict containing `dp_cp_group` entry.
+
+    If `metadata` is None, a new dict is returned with `dp_cp_group` set.
+    If `metadata` is a dict and missing `dp_cp_group`, it is updated in-place.
+
+    This function is adapted from megatron-lm's megatron.core.transformer.utils to avoid
+    dependency on megatron-lm's specific version.
+
+    Note:
+        This is a temporary method and will be removed once this function is merged to
+        megatron.core.transformer.utils in the main branch of megatron-lm.
+    """
+    if metadata is None:
+        metadata = {}
+    if "dp_cp_group" not in metadata:
+        try:
+            metadata["dp_cp_group"] = get_data_parallel_group(with_context_parallel=True)
+        except (AssertionError, RuntimeError):
+            # Fallback if context parallel is not initialized
+            metadata["dp_cp_group"] = get_data_parallel_group()
+    return metadata
+
+
 class _MegatronParallelLinear(_ParallelLinear):
     _functionals_to_replace = [
         (megatron_parallel, "linear_with_grad_accumulation_and_async_allreduce"),
@@ -285,6 +309,9 @@ def _parameter_to_keep_in_quantizer_state_dict(self, key):
         return False
 
     def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
+        # Ensure metadata has dp_cp_group to avoid None subscript errors
+        metadata = ensure_metadata_has_dp_cp_group(metadata)
+
         # [WAR]: although we disable output_layer quantization by default but it will
         # still be picked up by mtq.quantize since it is a ColumnParallelLinear. We need
         # to further ensure that its sharded state_dict has no scalars or amax since
@@ -294,7 +321,7 @@ def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
         # state_dict mismatch.
         if prefix.endswith("output_layer."):
             # assert not any("_quantizer" in k for k in self.state_dict()), "quantized output_layer"
-            return super().sharded_state_dict(prefix, sharded_offsets)
+            return super().sharded_state_dict(prefix, sharded_offsets, metadata)
 
         quantizer_state_dict = {}
         for k, v in self.state_dict(prefix="", keep_vars=True).items():
@@ -310,7 +337,7 @@ def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
                 "Please use regular state_dict."
             )
         sharded_axis_dict = self._get_shard_axis_dict(quantizer_state_dict)
-        sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets)
+        sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata)
         sharded_state_dict.update(
             **make_sharded_tensors_for_checkpoint(
                 quantizer_state_dict, prefix, sharded_axis_dict, sharded_offsets
```
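A small self-contained sketch of the contract the new helper is meant to satisfy; `_FakeDpCpGroup` and the injected `group_factory` parameter are hypothetical stand-ins so the snippet runs without initializing torch.distributed, whereas the real helper obtains the group via `get_data_parallel_group` as in the diff above.

```python
# Self-contained sketch of the helper's contract. `_FakeDpCpGroup` is a
# hypothetical stand-in for a torch.distributed ProcessGroup so the snippet
# runs without Megatron-Core parallel state.
class _FakeDpCpGroup:
    pass


def ensure_metadata_has_dp_cp_group(metadata, group_factory=_FakeDpCpGroup):
    """Same contract as the diff's helper, with group creation injected."""
    if metadata is None:
        metadata = {}
    if "dp_cp_group" not in metadata:
        metadata["dp_cp_group"] = group_factory()
    return metadata


# None -> a new dict is returned with dp_cp_group set.
assert "dp_cp_group" in ensure_metadata_has_dp_cp_group(None)

# An existing dict is updated in-place; other entries are preserved.
existing = {"other_key": True}  # arbitrary pre-existing entry
ensure_metadata_has_dp_cp_group(existing)
assert "dp_cp_group" in existing and existing["other_key"] is True

# A caller-provided dp_cp_group is never overwritten.
provided = {"dp_cp_group": "caller-group"}
assert ensure_metadata_has_dp_cp_group(provided)["dp_cp_group"] == "caller-group"
```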

modelopt/torch/sparsity/weight_sparsity/plugins/megatron.py

Lines changed: 29 additions & 1 deletion
```diff
@@ -16,6 +16,7 @@
 """Support sparsify and save/resore for Megatron."""
 
 import megatron.core.transformer.mlp as megatron_mlp
+from megatron.core.parallel_state import get_data_parallel_group
 from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
 from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint
 
@@ -25,12 +26,39 @@
 from ..module import SparseModule, SpDMRegistry
 
 
+def ensure_metadata_has_dp_cp_group(metadata):
+    """Ensure `metadata` is a dict containing `dp_cp_group` entry.
+
+    If `metadata` is None, a new dict is returned with `dp_cp_group` set.
+    If `metadata` is a dict and missing `dp_cp_group`, it is updated in-place.
+
+    This function is adapted from megatron-lm's megatron.core.transformer.utils to avoid
+    dependency on megatron-lm's specific version.
+
+    Note:
+        This is a temporary method and will be removed once this function is merged to
+        megatron.core.transformer.utils in the main branch of megatron-lm.
+    """
+    if metadata is None:
+        metadata = {}
+    if "dp_cp_group" not in metadata:
+        try:
+            metadata["dp_cp_group"] = get_data_parallel_group(with_context_parallel=True)
+        except (AssertionError, RuntimeError):
+            # Fallback if context parallel is not initialized
+            metadata["dp_cp_group"] = get_data_parallel_group()
+    return metadata
+
+
 class _MegatronParallelLinear(SparseModule):
     def _get_shard_axis_dict(self, state_dict):
         raise NotImplementedError
 
     def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
-        sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets)
+        # Ensure metadata has dp_cp_group to avoid None subscript errors
+        metadata = ensure_metadata_has_dp_cp_group(metadata)
+
+        sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata)
 
         sparse_state_dict = {
             k: v
```
