22import logging
33import os
44import json
5- import tqdm
5+ from tqdm import tqdm
66import argparse
77
88from datasets import load_dataset
@@ -39,46 +39,16 @@ def CETT(threshold):
3939 return threshold
4040
4141
42- def find_threshold (model , dataloader , layer_idx , cett_target = 0.2 , n_quantiles = 500 ):
43- model .activation_capture = model .ACTIVATION_CAPTURE (model )
44- model .activation_capture .register_hooks (hooks = [Hook .UP ])
45-
46- thresholds = defaultdict (list )
47-
48- with torch .no_grad ():
49- for batch in dataloader :
50- input_ids = batch ["input_ids" ]
51- attention_mask = batch ["attention_mask" ]
52-
53- model .activation_capture .clear_captures ()
54-
55- _ = model (input_ids = input_ids , attention_mask = attention_mask )
56-
57- for layer ,layer_idx in enumerate (model .activation_capture .get_layers ()):
58- activations = model .activation_capture .mlp_activations [Hook .UP ][layer_idx ]
59- activations = activations .view (- 1 , activations .size (- 1 ))
60-
61- for i in range (activations .size (0 )):
62- neuron_outputs = activations [i ] * layer .mlp .down_proj .weight
63- threshold = calculate_threshold_one_token (neuron_outputs , cett_target = cett_target , n_quantiles = n_quantiles )
64- thresholds [layer_idx ].append (threshold )
65-
66- for layer_idx , layer_thresholds in thresholds .items ():
67- thresholds [layer_idx ] = sum (layer_thresholds ) / len (layer_thresholds )
68-
69- return thresholds
70-
71-
72-
7342def find_thresholds (
74- model_name ,
75- dataset_name ,
76- dataset_config ,
77- max_samples ,
78- cett_target ,
79- n_quantiles ,
80- save_path ,
81- device ,
43+ model_name : str ,
44+ dataset_name : str ,
45+ dataset_config : str ,
46+ max_samples : int ,
47+ cett_target : float ,
48+ n_quantiles : int ,
49+ save_path : str ,
50+ seed : int ,
51+ device : torch .device ,
8252 ):
8353
8454 # Load tokenizer and model
@@ -96,7 +66,7 @@ def find_thresholds(
9666 model = model .to (device )
9767
9868 model .eval ()
99- model .activation_capture = model . ACTIVATION_CAPTURE (model )
69+ model .activation_capture = ActivationCapture (model )
10070 model .activation_capture .register_hooks (hooks = [Hook .UP ])
10171
10272 # Load dataset
@@ -107,7 +77,7 @@ def find_thresholds(
10777 )
10878 else :
10979 dataset = load_dataset (dataset_name , split = "train" , streaming = True )
110- dataset = dataset .shuffle (buffer_size = 10000 , seed = 42 )
80+ dataset = dataset .shuffle (buffer_size = 10000 , seed = seed )
11181
11282 def sample_and_tokenize (examples ):
11383 """Sample text chunks before tokenization for efficiency using vectorized operations."""
@@ -130,13 +100,12 @@ def sample_and_tokenize(examples):
130100 logger .info (f"Beginning to compute thresholds using { max_samples } samples" )
131101 thresholds = defaultdict (list )
132102 with torch .no_grad ():
133- for batch in tqdm . tqdm (dataloader ):
103+ for batch in tqdm (dataloader , total = max_samples ):
134104 input_ids = batch ["input_ids" ].to (device )
135- attention_mask = batch ["attention_mask" ].to (device )
136105
137- _ = model (input_ids = input_ids , attention_mask = attention_mask )
106+ _ = model (input_ids . squeeze ( 0 ) )
138107
139- for layer , layer_idx in enumerate (model .activation_capture .get_layers ()):
108+ for layer_idx , layer in enumerate (model .activation_capture .get_layers ()):
140109 activations = model .activation_capture .mlp_activations [Hook .UP ][layer_idx ]
141110 activations = activations .view (- 1 , activations .size (- 1 ))
142111
@@ -232,12 +201,13 @@ def parse_args():
232201
233202 find_thresholds (
234203 model_name = args .model_name ,
235- dataset_name = args .dataset_name ,
204+ dataset_name = args .dataset ,
236205 dataset_config = args .dataset_config ,
237206 max_samples = args .max_samples ,
238207 cett_target = args .cett_target ,
239208 n_quantiles = args .n_quantiles ,
240209 save_path = args .save_path ,
241- device = device
210+ seed = args .seed ,
211+ device = device ,
242212 )
243213