from collections import defaultdict
import logging
import os
import json
import tqdm
import argparse

from datasets import load_dataset
import torch
from torch.utils.data import DataLoader as TorchDataLoader
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM
from transformers.trainer_utils import set_seed

from src.activation_capture import ActivationCapture, Hook

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

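# Threshold search follows the CETT ("cumulative error of tail truncation")
# recipe: per token, neuron outputs are ranked by magnitude, and the cutoff is
# the largest one at which truncating everything below it still keeps the
# relative error of the MLP output within cett_target.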
def calculate_threshold_one_token(neuron_outputs, cett_target, n_quantiles=1000):
    # Magnitude of each neuron's output for this token: column norms of the
    # (hidden_dim, intermediate_dim) neuron-output matrix.
    norms = neuron_outputs.norm(dim=0)
    # Candidate cutoffs: n_quantiles evenly spaced quantiles of those norms.
    quantiles = norms.quantile(torch.linspace(0, 1, n_quantiles))
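    # (The remainder of this function is elided in the diff. A plausible
    # continuation, given the CETT recipe above: for each candidate quantile t,
    # compute ||sum of neuron outputs with norm <= t|| / ||full MLP output||
    # and return the largest t keeping that ratio below cett_target.)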

# ... unchanged lines omitted: the end of calculate_threshold_one_token and
# the signature/opening of find_threshold(model, dataloader, layer_idx,
# cett_target=0.2, n_quantiles=...) ...
    model.activation_capture = model.ACTIVATION_CAPTURE(model)
    model.activation_capture.register_hooks(hooks=[Hook.UP])

    thresholds = defaultdict(list)

    with torch.no_grad():
        for batch in dataloader:
            # ... unchanged lines omitted: input_ids / attention_mask are
            # unpacked from `batch` here, as in find_thresholds below ...

            _ = model(input_ids=input_ids, attention_mask=attention_mask)

            for layer_idx, layer in enumerate(model.activation_capture.get_layers()):
                activations = model.activation_capture.mlp_activations[Hook.UP][layer_idx]
                activations = activations.view(-1, activations.size(-1))

                for i in range(activations.size(0)):
                    # Broadcasting scales each column of down_proj.weight by the
                    # matching activation, giving one output vector per neuron.
                    neuron_outputs = activations[i] * layer.mlp.down_proj.weight
                    threshold = calculate_threshold_one_token(neuron_outputs, cett_target=cett_target, n_quantiles=n_quantiles)
                    thresholds[layer_idx].append(threshold)

    for layer_idx, layer_thresholds in thresholds.items():
        thresholds[layer_idx] = sum(layer_thresholds) / len(layer_thresholds)

    return thresholds


def find_thresholds(
    model_name,
    dataset_name,
    dataset_config,
    max_samples,
    cett_target,
    n_quantiles,
    save_path,
    device,
):
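    """Compute averaged per-layer CETT thresholds and record them in a JSON file.

    Streams max_samples documents, runs the model with activation capture
    enabled, derives one threshold per token and layer, then averages within
    each layer and stores the result under model_name in save_path.
    """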

    # Load tokenizer and model
    logger.info(f"Loading model: {model_name}")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float32,
        device_map="auto" if device.type == "cuda" else None,
    )

    if device.type != "cuda":
        model = model.to(device)

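    # Capture hooks record every layer's MLP up-projection output (Hook.UP)
    # on each forward pass; ActivationCapture is the project's own helper.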
    model.eval()
    model.activation_capture = model.ACTIVATION_CAPTURE(model)
    model.activation_capture.register_hooks(hooks=[Hook.UP])

    # Load dataset
    logger.info(f"Loading dataset: {dataset_name}")
    if dataset_config:
        dataset = load_dataset(
            dataset_name, dataset_config, split="train", streaming=True
        )
    else:
        dataset = load_dataset(dataset_name, split="train", streaming=True)
    dataset = dataset.shuffle(buffer_size=10000, seed=42)

    def sample_and_tokenize(examples):
        """Tokenize one streamed example, keeping the fields used downstream."""
        texts = examples["text"]
        tokenized = tokenizer(texts, return_tensors="pt")

        return {
            "text": texts,
            "input_ids": tokenized["input_ids"],
            # The forward pass below needs the mask as well as the ids.
            "attention_mask": tokenized["attention_mask"],
        }

    # Tokenize
    dataset = dataset.take(max_samples).map(sample_and_tokenize, batched=False)
    dataset = dataset.with_format("torch")

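    # batch_size=1 avoids padding the variable-length streamed sequences; the
    # prefetching workers are presumably there to hide tokenization latency.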
    dataloader = TorchDataLoader(dataset, batch_size=1, num_workers=8, pin_memory=False, prefetch_factor=2)  # type: ignore

    # Compute thresholds for each layer across all dataset entries
    logger.info(f"Beginning to compute thresholds using {max_samples} samples")
    thresholds = defaultdict(list)
    with torch.no_grad():
        for batch in tqdm.tqdm(dataloader):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)

            _ = model(input_ids=input_ids, attention_mask=attention_mask)

            for layer_idx, layer in enumerate(model.activation_capture.get_layers()):
                activations = model.activation_capture.mlp_activations[Hook.UP][layer_idx]
                activations = activations.view(-1, activations.size(-1))

                for i in range(activations.size(0)):
                    # Broadcasting scales each column of down_proj.weight by the
                    # matching activation, giving one output vector per neuron.
                    neuron_outputs = activations[i] * layer.mlp.down_proj.weight
                    threshold = calculate_threshold_one_token(neuron_outputs, cett_target=cett_target, n_quantiles=n_quantiles)
                    thresholds[layer_idx].append(threshold)

            # Free captured activations (and GPU cache) before the next batch.
            model.activation_capture.clear_captures()
            if device.type == "cuda":
                torch.cuda.empty_cache()

    # Average per-token thresholds within each layer; float() keeps the values
    # JSON-serializable even if the per-token thresholds are 0-dim tensors.
    for layer_idx, layer_thresholds in thresholds.items():
        thresholds[layer_idx] = float(sum(layer_thresholds) / len(layer_thresholds))

    # Save layerwise thresholds as a record in the central JSON file
    if os.path.exists(save_path):
        with open(save_path, mode="r", encoding="utf-8") as read_file:
            threshold_dict = json.load(read_file)
    else:
        threshold_dict = {}
    threshold_dict[model_name] = thresholds
    with open(save_path, mode="w", encoding="utf-8") as write_file:
        json.dump(threshold_dict, write_file)
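    # Illustrative layout of the saved file (JSON stringifies the layer keys):
    # {"meta-llama/Llama-2-7b-hf": {"0": 0.0123, "1": 0.0147, ...}}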


def parse_args():
    parser = argparse.ArgumentParser(
        description="Find per-layer CETT-based activation thresholds"
    )
    parser.add_argument(
        "--model_name",
        type=str,
        required=True,
        help="Name or path of the base model (e.g., meta-llama/Llama-2-7b-hf)",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="allenai/c4",
        help="Dataset name (default: allenai/c4)",
    )
    parser.add_argument(
        "--dataset_config",
        type=str,
        default="en",
        help="Dataset configuration (e.g., en for C4)",
    )
    parser.add_argument(
        "--save_path",
        type=str,
        default="thresholds.json",
        help="Path to JSON file for thresholds",
    )
    parser.add_argument(
        "--max_samples",
        type=int,
        default=500,
        help="Maximum number of samples to process",
    )
    parser.add_argument(
        "--cett_target",
        type=float,
        default=0.2,
        help="Target CETT value for threshold-finding",
    )
    parser.add_argument(
        "--n_quantiles",
        type=int,
        default=500,
        help="Number of quantiles to sort neuron outputs into for threshold-finding",
    )
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument(
        "--device", type=str, default="auto", help="Device to use (auto, cpu, cuda)"
    )

    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()

    # Set seed
    set_seed(args.seed)

    # Setup device
    if args.device == "auto":
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device(args.device)

    find_thresholds(
        model_name=args.model_name,
        dataset_name=args.dataset,
        dataset_config=args.dataset_config,
        max_samples=args.max_samples,
        cett_target=args.cett_target,
        n_quantiles=args.n_quantiles,
        save_path=args.save_path,
        device=device,
    )
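
# Example invocation (script path and values are illustrative):
#   python find_thresholds.py --model_name meta-llama/Llama-2-7b-hf \
#       --dataset allenai/c4 --dataset_config en --max_samples 500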