|
7 | 7 |
|
8 | 8 | from datasets import load_dataset |
9 | 9 | import torch |
| 10 | +from torch.utils.data import DataLoader as TorchDataLoader |
10 | 11 | from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM |
11 | 12 | from transformers.trainer_utils import set_seed |
12 | 13 |
|
|
16 | 17 | logging.basicConfig(level=logging.INFO) |
17 | 18 | logger = logging.getLogger(__name__) |
18 | 19 |
|
19 | | -def calculate_threshold_one_token(neuron_outputs, cett_target, n_quantiles=1000): |
20 | | - norms = neuron_outputs.norm(dim=0) |
21 | | - quantiles = norms.quantile(torch.linspace(0,1,n_quantiles)) |
22 | | - tot_norm = neuron_outputs.sum(dim=1).norm() |
23 | 20 |
|
24 | | - def CETT(threshold): |
25 | | - threshold_norm = ((norms < threshold) * neuron_outputs).sum(dim=1).norm() |
26 | | - return threshold_norm / tot_norm |
| 21 | +import copy |
class ThresholdEvaluator:
    """Capture per-layer MLP outputs with and without activation thresholding.

    The evaluator registers forward hooks on ``model.model.layers[i].mlp`` to
    record each layer's MLP output, and on ``...mlp.act_fn`` to zero every
    activation that does not exceed that neuron's threshold.

    Attributes:
        model: The wrapped model. It must expose ``model.model.layers`` and
            ``model.config.intermediate_size``, and every layer must have
            ``mlp.down_proj.weight`` and ``mlp.act_fn``.
        thresholds: The per-layer thresholds passed to the constructor;
            indexed by layer index.
        neuron_thresholds: Tensor of shape ``(n_layers, intermediate_size)``
            holding ``thresholds[layer] * ||down_proj.weight[:, neuron]||``.
        mlp_outputs: ``defaultdict(list)`` mapping a layer index to the list
            of MLP outputs captured so far.
        handles: Handles of the currently registered hooks.
    """

    def __init__(self, model, thresholds):
        self.model = model
        self.thresholds = thresholds

        self.compute_neuron_thresholds(thresholds)

        self.mlp_outputs = defaultdict(list)
        self.handles = []

    def get_layers(self):
        """Return the model's sequence of decoder layers."""
        return self.model.model.layers

    def compute_neuron_thresholds(self, thresholds):
        """Scale each layer's threshold by the column norms of ``down_proj``.

        The result is stored in ``self.neuron_thresholds``.

        NOTE(review): the threshold is *multiplied* by the column norm; if
        ``thresholds`` are expressed on neuron-output norms, a division may be
        what is intended -- confirm against the code that produced them.
        """
        layers = self.get_layers()
        self.neuron_thresholds = torch.zeros(
            len(layers), self.model.config.intermediate_size
        )
        with torch.no_grad():
            # enumerate() is required: iterating the layer list directly
            # yields modules, not (index, module) pairs.
            for layer_idx, layer in enumerate(layers):
                norms = layer.mlp.down_proj.weight.norm(dim=0)
                self.neuron_thresholds[layer_idx] = thresholds[layer_idx] * norms

    def _inspect_hook(self, layer_idx):
        """Build a forward hook that records the module output for a layer."""
        def hook(module, input, output):
            # Store a detached copy flattened to (-1, last_dim) so later
            # in-place changes to the live tensor cannot alter the capture.
            captured = output.detach().reshape(-1, output.size(-1)).clone()
            self.mlp_outputs[layer_idx].append(captured)
            return output
        return hook

    def _threshold_hook(self, layer_idx):
        """Build a forward hook that zeroes activations not above threshold."""
        def hook(module, input, output):
            # The threshold table is allocated on the default device; move the
            # row to wherever the activation lives before comparing.
            threshold = self.neuron_thresholds[layer_idx].to(output.device)
            # NOTE(review): this is a signed comparison, so negative
            # activations are always zeroed -- confirm that is intended.
            return output * (output > threshold)
        return hook

    def apply_thresholds(self):
        """Register the thresholding hook on every layer's ``mlp.act_fn``."""
        for layer_idx, layer in enumerate(self.get_layers()):
            handle = layer.mlp.act_fn.register_forward_hook(
                self._threshold_hook(layer_idx)
            )
            self.handles.append(handle)

    def apply_hooks(self):
        """Register the output-capturing hook on every layer's ``mlp``."""
        for layer_idx, layer in enumerate(self.get_layers()):
            handle = layer.mlp.register_forward_hook(
                self._inspect_hook(layer_idx)
            )
            self.handles.append(handle)

    def clear_captures(self):
        """Drop all captured MLP outputs."""
        self.mlp_outputs = defaultdict(list)

    def remove_hooks(self):
        """Remove every hook registered through this evaluator."""
        for handle in self.handles:
            handle.remove()
        self.handles = []

    def _run_model(self, inputs):
        """Run the model once over every entry of ``inputs`` without grad."""
        with torch.no_grad():
            for inp in inputs:
                _ = self.model(**inp)

    def _collect_captures(self):
        """Concatenate the captured outputs of each layer along dim 0."""
        return {
            idx: torch.cat(outputs_idx, dim=0)
            for idx, outputs_idx in self.mlp_outputs.items()
        }

    def evaluate(self, inputs):
        """Run ``inputs`` through the model unmodified, then with thresholds.

        Args:
            inputs: Iterable of keyword-argument mappings, each passed as
                ``self.model(**inp)``. It is iterated twice, so it must be
                re-iterable (a list or DataLoader, not a one-shot generator).

        Returns:
            A ``(ground_truth_outputs, threshold_outputs)`` tuple. Each is a
            dict mapping layer index to the captured MLP outputs of all
            inputs, concatenated along dim 0.

        All hooks held by this evaluator are removed and the capture buffer is
        cleared on exit, including when a forward pass raises.
        """
        try:
            self.apply_hooks()
            self._run_model(inputs)
            ground_truth_outputs = self._collect_captures()
            self.clear_captures()

            # The capture hooks stay registered; the threshold hooks are
            # added on top so the second pass records thresholded outputs.
            self.apply_thresholds()
            self._run_model(inputs)
            threshold_outputs = self._collect_captures()
        finally:
            self.remove_hooks()
            self.clear_captures()

        return ground_truth_outputs, threshold_outputs
| 101 | + |
| 102 | + |
| 103 | + |
| 104 | +# |
| 105 | +# TODO: |
| 106 | +# 1. Test out precomputing down_proj norms and see if that improves performance |
| 107 | +# 2. Ensure that the thresholds lead to reasonable results for downstream evaluation |
| 108 | +# |
| 109 | +# |
| 110 | + |
| 111 | + |
| 112 | + |
| 113 | + |
def cett_from_threshold(neuron_outputs, threshold, norms=None, tot_norm=None):
    """Return the fraction of the output norm carried by sub-threshold neurons.

    Each slice ``neuron_outputs[..., :, j]`` is treated as one neuron's
    contribution. Neurons whose contribution norm is below ``threshold`` are
    summed, and the norm of that sum is divided by the norm of the sum over
    all neurons.

    Args:
        neuron_outputs: Tensor whose last dimension indexes neurons and whose
            second-to-last dimension is the one the per-neuron norm is taken
            over. Leading dimensions, if any, are kept in the result.
        threshold: Scalar (or broadcastable tensor) compared against the
            per-neuron norms.
        norms: Optional precomputed
            ``neuron_outputs.norm(dim=-2).unsqueeze(-2)``.
        tot_norm: Optional precomputed
            ``neuron_outputs.sum(dim=-1).norm(dim=-1)``.

    Returns:
        Tensor of ratios with the last two dimensions reduced away. Entries
        are NaN or inf where the total norm is zero.
    """
    # `is None` rather than truthiness: `not tensor` raises for any tensor
    # with more than one element, which is the normal case for `norms`.
    if norms is None:
        norms = neuron_outputs.norm(dim=-2).unsqueeze(-2)
    if tot_norm is None:
        tot_norm = neuron_outputs.sum(dim=-1).norm(dim=-1)
    threshold_norm = ((norms < threshold) * neuron_outputs).sum(dim=-1).norm(dim=-1)
    return threshold_norm / tot_norm
| 120 | + |
| 121 | +''' |
| 122 | +def calculate_threshold_by_token(neuron_outputs, cett_target, n_thresholds=10000): |
| 123 | + neuron_outputs = neuron_outputs.view(-1, *neuron_outputs.size()[-2:]) |
| 124 | + norms = neuron_outputs.norm(dim=-2).unsqueeze(-2) |
| 125 | + min_value = norms.min() |
| 126 | + max_value = norms.quantile(0.99) |
| 127 | + threshold_grid = torch.linspace(min_value, max_value, n_thresholds) |
| 128 | + tot_norm = neuron_outputs.sum(dim=-1).norm(dim=-1) |
| 129 | + thresholds = torch.zeros(neuron_outputs.size(0)) |
| 130 | + |
| 131 | + initial_cett = cett_from_threshold(neuron_outputs, max_value, norms=norms, tot_norm=tot_norm) |
| 132 | + thresholds[initial_cett < cett_target] = max_value |
| 133 | + |
| 134 | + for j in tqdm(range(neuron_outputs.size(0))): |
| 135 | + if thresholds[j] == 0: |
| 136 | + left = 0 |
| 137 | + right = n_thresholds |
| 138 | + while left < right: |
| 139 | + mid = (left + right) // 2 |
| 140 | + cett = cett_from_threshold(neuron_outputs[j], threshold_grid[mid], norms=norms[j], tot_norm=tot_norm[j]) |
| 141 | + if cett <= cett_target: |
| 142 | + left = mid + 1 |
| 143 | + else: |
| 144 | + right = mid |
| 145 | + thresholds[j] = threshold_grid[left] |
| 146 | + return thresholds |
| 147 | +''' |
| 148 | + |
def calculate_threshold(neuron_outputs, cett_target, n_thresholds=10000):
    """Binary-search a norm threshold whose mean CETT meets ``cett_target``.

    Candidate thresholds are ``n_thresholds`` evenly spaced values between the
    smallest per-neuron norm and the 0.99 quantile of the per-neuron norms.
    The search assumes the mean CETT does not decrease as the threshold grows.

    Args:
        neuron_outputs: Tensor of per-neuron contributions; all leading
            dimensions are flattened so the last two are kept as-is.
        cett_target: Upper bound on the mean value returned by
            ``cett_from_threshold``.
        n_thresholds: Number of grid points to search over (at least 1).

    Returns:
        A 0-dim tensor: the first grid threshold whose mean CETT exceeds
        ``cett_target``, or the largest grid threshold if none exceeds it.

    Raises:
        ValueError: If ``n_thresholds`` is smaller than 1.
    """
    if n_thresholds < 1:
        raise ValueError("n_thresholds must be at least 1")

    neuron_outputs = neuron_outputs.view(-1, *neuron_outputs.size()[-2:])
    # Computed once here and reused for every probe of the search.
    norms = neuron_outputs.norm(dim=-2).unsqueeze(-2)
    tot_norm = neuron_outputs.sum(dim=-1).norm(dim=-1)

    min_value = norms.min()
    max_value = norms.quantile(0.99)
    threshold_grid = torch.linspace(min_value, max_value, n_thresholds)

    left = 0
    right = n_thresholds
    while left < right:
        mid = (left + right) // 2
        cett = cett_from_threshold(
            neuron_outputs, threshold_grid[mid], norms=norms, tot_norm=tot_norm
        ).mean()
        if cett <= cett_target:
            left = mid + 1
        else:
            right = mid

    # `left` equals n_thresholds when every grid point satisfies the target;
    # clamp so that case returns the largest threshold instead of raising
    # IndexError.
    return threshold_grid[min(left, n_thresholds - 1)]
40 | 172 |
|
41 | 173 |
|
42 | 174 | def find_thresholds( |
43 | 175 | model_name: str, |
44 | 176 | dataset_name: str, |
45 | 177 | dataset_config: str, |
46 | | - max_samples: int, |
47 | | - cett_target: float, |
48 | | - n_quantiles: int, |
49 | 178 | save_path: str, |
50 | | - seed: int, |
51 | | - device: torch.device, |
| 179 | + batch_size: int = 8, |
| 180 | + max_samples: int = 128, |
| 181 | + max_length: int = 256, |
| 182 | + cett_target: float = 0.2, |
| 183 | + n_thresholds: int = 10000, |
| 184 | + num_workers: int = 8, |
| 185 | + seed: int = 42, |
| 186 | + device: torch.device = torch.device("cpu"), |
52 | 187 | ): |
53 | 188 |
|
54 | 189 | # Load tokenizer and model |
@@ -82,37 +217,41 @@ def find_thresholds( |
82 | 217 | def sample_and_tokenize(examples): |
83 | 218 | """Sample text chunks before tokenization for efficiency using vectorized operations.""" |
84 | 219 | texts = examples["text"] |
85 | | - tokenized = tokenizer(texts, return_tensors="pt") |
| 220 | + tokenized = tokenizer( |
| 221 | + texts, |
| 222 | + max_length=max_length, |
| 223 | + truncation=True, |
| 224 | + return_tensors="pt" |
| 225 | + ) |
86 | 226 |
|
87 | 227 | # Convert to lists |
88 | 228 | return { |
89 | | - "text": texts, |
90 | 229 | "input_ids": tokenized["input_ids"], |
| 230 | + "attention_mask": tokenized["attention_mask"] |
91 | 231 | } |
92 | 232 |
|
93 | 233 | # Tokenize |
94 | 234 | dataset = dataset.take(max_samples).map(sample_and_tokenize, batched=False) |
95 | 235 | dataset = dataset.with_format("torch") |
96 | 236 |
|
97 | | - dataloader = TorchDataLoader(dataset, batch_size=1, num_workers=8, pin_memory=False, prefetch_factor=2) # type: ignore |
| 237 | + dataloader = TorchDataLoader(dataset, batch_size=1, num_workers=num_workers, pin_memory=False, prefetch_factor=2) # type: ignore |
98 | 238 |
|
99 | 239 | # Compute thresholds for each layer across all dataset entries |
100 | 240 | logger.info(f"Beginning to compute thresholds using {max_samples} samples") |
101 | 241 | thresholds = defaultdict(list) |
102 | 242 | with torch.no_grad(): |
103 | 243 | for batch in tqdm(dataloader, total=max_samples): |
104 | 244 | input_ids = batch["input_ids"].to(device) |
| 245 | + attention_mask = batch["attention_mask"].to(device) |
105 | 246 |
|
106 | | - _ = model(input_ids.squeeze(0)) |
| 247 | + _ = model(input_ids=input_ids.squeeze(0), attention_mask=attention_mask.squeeze(0)) |
107 | 248 |
|
108 | 249 | for layer_idx, layer in enumerate(model.activation_capture.get_layers()): |
| 250 | + down_weight = layer.mlp.down_proj.weight |
109 | 251 | activations = model.activation_capture.mlp_activations[Hook.UP][layer_idx] |
110 | | - activations = activations.view(-1, activations.size(-1)) |
111 | | - |
112 | | - for i in range(activations.size(0)): |
113 | | - neuron_outputs = activations[i] * layer.mlp.down_proj.weight |
114 | | - threshold = calculate_threshold_one_token(neuron_outputs, cett_target=cett_target, n_quantiles=n_quantiles) |
115 | | - thresholds[layer_idx].append(threshold) |
| 252 | + neuron_outputs = activations.unsqueeze(-2) * down_weight |
| 253 | + threshold = calculate_threshold(neuron_outputs, cett_target, n_thresholds) |
| 254 | + thresholds[layer_idx].append(threshold) |
116 | 255 |
|
117 | 256 | model.activation_capture.clear_captures() |
118 | 257 | if device.type == "cuda": |
|
0 commit comments