
Commit 1dde209

Add score_pruning_activations (step 2/6) (#563)
## What does this PR do?

- Add score_pruning_activations.py

Notes:

- validate_model.py still depends on Nvidia internal code (to be changed in a subsequent MR)
- sharded_checkpoint_utils.py still needs to use DeciLM from internal Nvidia code (to be changed in the next MR)

Signed-off-by: Daniel Korzekwa <dkorzekwa@nvidia.com>
Signed-off-by: Daniel Korzekwa <daniel.korzekwa@gmail.com>
Co-authored-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
1 parent 07ca24d commit 1dde209

File tree

6 files changed: +568 −5 lines changed

6 files changed

+568
-5
lines changed
modelopt/torch/_compress/activation_scoring/score_pruning_activations.py

Lines changed: 174 additions & 0 deletions

@@ -0,0 +1,174 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pathlib import Path

import hydra
import torch
from omegaconf import DictConfig
from utils.parsing import format_global_config

from modelopt.torch._compress.tools.hydra_utils import register_hydra_resolvers
from modelopt.torch._compress.tools.logger import mprint
from modelopt.torch._compress.tools.runtime import BaseRuntime, NativeDdpRuntime
from modelopt.torch._compress.tools.validate_model import validate_model
from modelopt.torch._compress.utils.dist_utils import is_distributed


def has_checkpoint_support(activation_hooks_kwargs: dict) -> bool:
    """
    Determine if the activation hook method has proper checkpoint support implemented.

    Args:
        activation_hooks_kwargs: Hook configuration

    Returns:
        bool: True if the hook method has save_state/load_state implemented
    """
    method = activation_hooks_kwargs.get("method", "")

    # Methods with implemented checkpoint support
    supported_methods = {
        "iterative",  # IterativeChannelContributionHook: save_state/load_state implemented
        "independent",  # IndependentChannelContributionHook: save_state/load_state implemented
        "stats",  # RouterStatsHook: save_state/load_state implemented
        "ranked_choice_voting",  # RankedChoiceVotingHook: save_state/load_state implemented
    }

    return method in supported_methods


def check_scoring_completion(
    activations_log_dir: str, runtime, activation_hooks_kwargs=None
) -> bool:
    """
    Check if scoring is already completed by looking for the expected output files.
    Also checks if the scoring method is safe for resume.

    Args:
        activations_log_dir: Directory where activation logs should be stored
        runtime: Runtime object for distributed processing
        activation_hooks_kwargs: Hook configuration to check if resume is safe

    Returns:
        bool: True if scoring is completed (has rank files and args.json)
    """
    # Only check completion on main process (or if no distributed runtime)
    if runtime is None or runtime.is_main_process:
        log_dir = Path(activations_log_dir)

        # Check if directory exists
        if not log_dir.exists():
            return False

        # Check for rank files (at least rank_0.pth should exist)
        rank_files = list(log_dir.glob("rank_*.pth"))

        if not rank_files:
            return False

        # Check for args.json (created by main process)
        args_file = log_dir / "args.json"
        has_args_json = args_file.exists()

        # Check for completion: if we have rank files and args.json, scoring is complete
        if rank_files and has_args_json:
            # Add optional completion info for debugging
            mprint(f"Found completed scoring in {activations_log_dir}")
            mprint(f" - Found {len(rank_files)} rank files")
            mprint(f" - Found args.json: {has_args_json}")

            return True

    return False


def should_skip_scoring_completely(cfg: DictConfig, runtime) -> bool:
    """
    Determine if we should skip scoring entirely (only if 100% complete).
    Partial progress should proceed to validate_model for proper resume.

    Args:
        cfg: Configuration object
        runtime: Runtime object for distributed processing

    Returns:
        bool: True if we should skip scoring (100% completed), False if we should run/resume it
    """
    # Check if activations_log_dir is specified
    if not hasattr(cfg.pruning, "activations_log_dir") or cfg.pruning.activations_log_dir is None:
        mprint("No activations_log_dir specified, running scoring")
        return False

    # Check for force restart flag
    force_restart = getattr(cfg.pruning, "force_restart_scoring", False)
    if force_restart:
        mprint("Force restart flag set, will restart scoring regardless of existing artifacts")
        return False

    # Get hook configuration to check if resume is mathematically safe
    activation_hooks_kwargs = getattr(cfg.pruning, "activation_hooks_kwargs", {})

    # Check if scoring is already completed
    is_completed = check_scoring_completion(
        cfg.pruning.activations_log_dir, runtime, activation_hooks_kwargs
    )

    # Broadcast the result to all processes in distributed mode
    if runtime is not None and runtime.world_size > 1:
        should_skip = [is_completed]  # Use list for mutable object
        torch.distributed.broadcast_object_list(should_skip, src=0)
        is_completed = should_skip[0]

    if is_completed:
        mprint("Scoring 100% completed, skipping...")

    return is_completed


# Old progress tracking removed - checkpoint manager handles all progress tracking


def launch_score_activations(cfg: DictConfig, runtime):
    # Check if we should skip scoring entirely (only if 100% complete)
    if should_skip_scoring_completely(cfg, runtime):
        return

    mprint("Starting pruning activation scoring...")

    # The checkpoint manager inside validate_model handles all progress tracking
    validate_model(args=cfg.pruning, runtime=runtime)


@hydra.main("", version_base="1.3")
def main(cfg: DictConfig) -> None:
    cfg = hydra.utils.instantiate(cfg)
    mprint(format_global_config(cfg, title="Score Pruning Activations"))

    _runtime = (
        NativeDdpRuntime(
            dtype=torch.bfloat16, torch_distributed_timeout=getattr(cfg, "nccl_timeout_minutes")
        )
        if is_distributed()
        else BaseRuntime(dtype=torch.bfloat16)
    )
    with _runtime as runtime:
        launch_score_activations(cfg, runtime)
        runtime.wait_for_everyone()


if __name__ == "__main__":
    register_hydra_resolvers()
    main()
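As context for has_checkpoint_support: the methods it whitelists are exactly those whose hook classes implement a save_state/load_state pair, which is what makes mid-run resume safe. A minimal sketch of what such a contract might look like (the class, fields, and update logic below are illustrative, not from this PR; only the save_state/load_state method names come from the code above):

# Hypothetical sketch of the save_state/load_state contract assumed by
# has_checkpoint_support(). Class and field names here are made up.
import torch


class ToyChannelScoreHook:
    """Toy hook that accumulates a per-channel score across batches."""

    def __init__(self, num_channels: int) -> None:
        self.scores = torch.zeros(num_channels)
        self.num_batches = 0

    def update(self, activations: torch.Tensor) -> None:
        # Accumulate mean absolute activation per channel (dim 0 = batch).
        self.scores += activations.abs().mean(dim=0)
        self.num_batches += 1

    def save_state(self, path: str) -> None:
        # Persist partial progress so an interrupted scoring run can resume.
        torch.save({"scores": self.scores, "num_batches": self.num_batches}, path)

    def load_state(self, path: str) -> None:
        # Restore exactly what save_state wrote.
        state = torch.load(path)
        self.scores = state["scores"]
        self.num_batches = state["num_batches"]

Methods without such a pair cannot checkpoint partial scores, which is why the resume path above only trusts fully completed runs (rank_*.pth files plus args.json) and otherwise defers to the checkpoint manager inside validate_model.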

modelopt/torch/_compress/nas/plugins/compress_nas_plugin.py

Lines changed: 1 addition & 1 deletion

@@ -26,11 +26,11 @@
 import build_library_and_stats
 import mip_and_realize_models
 import pruning_ckpts
-import score_pruning_activations
 import scoring
 import torch
 from torch import nn

+from modelopt.torch._compress.activation_scoring import score_pruning_activations
 from modelopt.torch._compress.decilm.converters.convert_llama3_to_decilm import (
     convert_llama3_to_decilm,
 )
modelopt/torch/_compress/tools/sharded_checkpoint_utils.py

Lines changed: 4 additions & 4 deletions

@@ -29,24 +29,24 @@
 import torch.distributed
 import torch.nn as nn
 from huggingface_hub import split_torch_state_dict_into_shards
+from puzzle_tools.deci_lm_hf_code.modeling_decilm import DeciLMForCausalLM
 from safetensors import safe_open
 from safetensors.torch import load_file as safe_load_file
 from safetensors.torch import save_file as safe_save_file
 from tqdm import tqdm
 from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME
 from transformers.utils.hub import cached_file, get_checkpoint_shard_files
 from typing_extensions import override
-from utils.utils import EmptyInitOnDevice

 from modelopt.torch._compress.decilm.deci_lm_hf_code.configuration_decilm import DeciLMConfig
 from modelopt.torch._compress.decilm.deci_lm_hf_code.modeling_decilm import (
     DeciLMDecoderLayer,
-    DeciLMForCausalLM,
     rope_type_to_class,
 )
 from modelopt.torch._compress.tools.checkpoint_utils import load_model_config, load_state_dict
 from modelopt.torch._compress.tools.logger import mprint
 from modelopt.torch._compress.tools.runtime import IRuntime
+from modelopt.torch._compress.utils.utils import EmptyInitOnDevice


 class DummyModule(nn.Module):

@@ -392,7 +392,7 @@ def load_sharded_state_dict(
             partial_state_dict.update(shard)
     else:
         with safe_open(safetensors_path, framework="pt", device=str(device)) as f:
-            for key in f:
+            for key in f.keys():  # noqa: SIM118 - safe_open objects require .keys(), not directly iterable
                 if key in keys_to_load:
                     partial_state_dict[key] = f.get_tensor(key)
     return partial_state_dict

@@ -417,6 +417,6 @@ def load_state_dict_shapes(model_name_or_path: str | Path) -> dict[str, tuple]:
     state_dict_shapes = {}
     for safetensors_path in shard_paths:
         with safe_open(safetensors_path, framework="pt") as f:
-            for key in f:
+            for key in f.keys():  # noqa: SIM118 - safe_open objects require .keys(), not directly iterable
                 state_dict_shapes[key] = tuple(f.get_tensor(key).shape)
     return state_dict_shapes
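The `for key in f.keys()` change works around the fact that the handle returned by safetensors' safe_open is not directly iterable, so `for key in f:` fails at runtime; the `# noqa: SIM118` suppresses the lint rule that would otherwise flag `.keys()` as redundant. A self-contained sketch of the pattern (the file name and tensor are placeholders):

# Demonstrates iterating a safetensors file via safe_open: the handle exposes
# .keys() and .get_tensor(), but no __iter__. File/tensor names are made up.
import torch
from safetensors import safe_open
from safetensors.torch import save_file

save_file({"w": torch.ones(2, 3)}, "toy.safetensors")

with safe_open("toy.safetensors", framework="pt", device="cpu") as f:
    for key in f.keys():  # noqa: SIM118 - handle is not directly iterable
        print(key, tuple(f.get_tensor(key).shape))  # prints: w (2, 3)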
