From 012892e3c56166ce28957059cd74f1bd487884f5 Mon Sep 17 00:00:00 2001 From: Mio_Nora Date: Sun, 22 Dec 2024 17:12:35 +0800 Subject: [PATCH 1/2] fix structured pruning mode --- pruning/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pruning/utils.py b/pruning/utils.py index af2d13f..f76848e 100644 --- a/pruning/utils.py +++ b/pruning/utils.py @@ -66,7 +66,7 @@ def get_filter_mask(model, rate, args): threshold = np.sort(importance)[int(len(importance) * rate)] #threshold = np.percentile(importance, rate) - filter_mask = np.greater(importance, threshold) + filter_mask = np.greater(importance_all, threshold) return filter_mask From aaa128982836daf4eb75264e2323f6bf31c4620f Mon Sep 17 00:00:00 2001 From: Mio_Nora Date: Sun, 22 Dec 2024 17:14:27 +0800 Subject: [PATCH 2/2] fix structured pruning mode --- main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/main.py b/main.py index 6e0eef5..ac29a56 100644 --- a/main.py +++ b/main.py @@ -244,7 +244,7 @@ def train(args, train_loader, epoch, model, criterion, optimizer, **kwargs): if (globals()['iterations']+1) % args.prune_freq==0 and (epoch+1) <= args.milestones[1]: target_sparsity = args.prune_rate - args.prune_rate * (1 - (globals()['iterations'])/(args.milestones[1] * len(train_loader)))**3 if args.prune_type == 'structured': - filter_mask = pruning.get_filter_mask(model, rate, args) + filter_mask = pruning.get_filter_mask(model, target_sparsity, args) pruning.filter_prune(model, filter_mask) elif args.prune_type == 'unstructured': threshold = pruning.get_weight_threshold(model, target_sparsity, args)