diff --git a/main.py b/main.py index 6e0eef5..ac29a56 100644 --- a/main.py +++ b/main.py @@ -244,7 +244,7 @@ def train(args, train_loader, epoch, model, criterion, optimizer, **kwargs): if (globals()['iterations']+1) % args.prune_freq==0 and (epoch+1) <= args.milestones[1]: target_sparsity = args.prune_rate - args.prune_rate * (1 - (globals()['iterations'])/(args.milestones[1] * len(train_loader)))**3 if args.prune_type == 'structured': - filter_mask = pruning.get_filter_mask(model, rate, args) + filter_mask = pruning.get_filter_mask(model, target_sparsity, args) pruning.filter_prune(model, filter_mask) elif args.prune_type == 'unstructured': threshold = pruning.get_weight_threshold(model, target_sparsity, args) diff --git a/pruning/utils.py b/pruning/utils.py index af2d13f..f76848e 100644 --- a/pruning/utils.py +++ b/pruning/utils.py @@ -66,7 +66,7 @@ def get_filter_mask(model, rate, args): threshold = np.sort(importance)[int(len(importance) * rate)] #threshold = np.percentile(importance, rate) - filter_mask = np.greater(importance, threshold) + filter_mask = np.greater(importance_all, threshold) return filter_mask