Skip to content

Commit 194b532

Browse files
Add sewing kit and utilities used for pruning scoring - pruning scoring is self-contained now (#584)
## What does this PR do? Add sewing kit and utilities used for pruning scoring - pruning scoring is self-contained now - no dependency on internal Nvidia code. --------- Signed-off-by: Daniel Korzekwa <dkorzekwa@nvidia.com> Signed-off-by: Daniel Korzekwa <daniel.korzekwa@gmail.com> Co-authored-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
1 parent f10be0d commit 194b532

File tree

21 files changed

+4605
-14
lines changed

21 files changed

+4605
-14
lines changed

modelopt/torch/_compress/activation_scoring/score_pruning_activations.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import hydra
1919
import torch
2020
from omegaconf import DictConfig
21-
from utils.parsing import format_global_config
21+
from modelopt.torch._compress.utils.parsing import format_global_config
2222

2323
from modelopt.torch._compress.tools.hydra_utils import register_hydra_resolvers
2424
from modelopt.torch._compress.tools.logger import mprint

modelopt/torch/_compress/compress.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,10 @@
2323
import build_library_and_stats
2424
import mip_and_realize_models
2525
import pruning_ckpts
26-
import score_pruning_activations
26+
import modelopt.torch._compress.activation_scoring.score_pruning_activations as score_pruning_activations
2727
import scoring
2828
from omegaconf import DictConfig
29-
from puzzle_tools.runtime import IRuntime
29+
from modelopt.torch._compress.tools.runtime import IRuntime
3030

3131
from modelopt.torch._compress.tools.hydra_utils import initialize_hydra_config_for_dir
3232

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
# mypy: ignore-errors
16+
from .core import (
17+
Needle,
18+
KnotException,
19+
LoopFoundException,
20+
InputsLoopFoundException,
21+
MultipleExternalNodesException,
22+
OnlyInternalNodesException,
23+
OutputsLoopFoundException,
24+
ExternalTarget,
25+
ModuleTarget,
26+
ConstantTarget,
27+
FunctionTarget,
28+
RemoteTarget,
29+
StitchedModule,
30+
StitchedModuleException,
31+
CantResolveNodeDependenciesException,
32+
StitchedModuleOutput,
33+
)
34+
from .passage import always_false_predicate, always_true_predicate, InputArgs
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import logging
17+
18+
logger = logging.getLogger("sewing_kit")
19+
logger.setLevel(logging.WARN)

0 commit comments

Comments
 (0)