|
18 | 18 | logger = logging.getLogger(__name__) |
19 | 19 |
|
20 | 20 |
|
21 | | -import copy |
22 | | -class ThresholdEvaluator(): |
23 | | - def __init__(self, model, thresholds): |
24 | | - self.model = model |
25 | | - self.thresholds = thresholds |
26 | | - |
27 | | - self.compute_neuron_thresholds(thresholds) |
28 | | - |
29 | | - self.mlp_outputs = defaultdict(list) |
30 | | - self.handles = [] |
31 | | - |
32 | | - def get_layers(self): |
33 | | - return self.model.model.layers |
34 | | - |
35 | | - def compute_neuron_thresholds(self, thresholds): |
36 | | - n_layers = len(self.get_layers()) |
37 | | - self.neuron_thresholds = torch.zeros(n_layers, self.model.config.intermediate_size) |
38 | | - with torch.no_grad(): |
39 | | - for layer_idx, layer in self.get_layers(): |
40 | | - norms = layer.mlp.down_proj.weight.norm(dim=0) |
41 | | - self.neuron_thresholds[layer_idx] = thresholds[layer_idx] * norms |
42 | | - |
43 | | - def _inspect_hook(self, layer_idx): |
44 | | - def hook(module, input, output): |
45 | | - # Just detach, don't clone or move to CPU yet |
46 | | - out = output.view(-1, output.size(-1)).clone().detach() |
47 | | - self.mlp_outputs[layer_idx].append(out) |
48 | | - return output |
49 | | - return hook |
50 | | - |
51 | | - def _threshold_hook(self, layer_idx): |
52 | | - def hook(module, input, output): |
53 | | - # Just detach, don't clone or move to CPU yet |
54 | | - mask = (output > self.neuron_thresholds[layer_idx]).bool() |
55 | | - return output * mask |
56 | | - return hook |
57 | | - |
58 | | - def apply_thresholds(self): |
59 | | - for layer_idx, layer in enumerate(self.get_layers()): |
60 | | - handle = layer.mlp.act_fn.register_forward_hook( |
61 | | - self._threshold_hook(layer_idx) |
62 | | - ) |
63 | | - self.handles.append(handle) |
64 | | - |
65 | | - def apply_hooks(self): |
66 | | - for layer_idx, layer in enumerate(self.get_layers()): |
67 | | - handle = layer.mlp.register_forward_hook( |
68 | | - self._inspect_hook(layer_idx) |
69 | | - ) |
70 | | - self.handles.append(handle) |
71 | | - |
72 | | - def clear_captures(self): |
73 | | - self.mlp_outputs = defaultdict(list) |
74 | | - |
75 | | - def remove_hooks(self): |
76 | | - for handle in self.handles: |
77 | | - handle.remove() |
78 | | - self.handles = [] |
79 | | - |
80 | | - def evaluate(self, inputs): |
81 | | - self.apply_hooks() |
82 | | - |
83 | | - with torch.no_grad(): |
84 | | - for inp in inputs: |
85 | | - _ = self.model(**inp) |
86 | | - |
87 | | - ground_truth_outputs = { |
88 | | - idx: torch.cat(outputs_idx, dim=0) for idx,outputs_idx in self.mlp_outputs |
89 | | - } |
90 | | - self.clear_captures() |
91 | | - |
92 | | - self.apply_thresholds() |
93 | | - with torch.no_grad(): |
94 | | - for inp in inputs: |
95 | | - _ = self.model(**inp) |
96 | | - |
97 | | - threshold_outputs = { |
98 | | - idx: torch.cat(outputs_idx, dim=0) for idx,outputs_idx in self.mlp_outputs |
99 | | - } |
100 | | - self.clear_captures() |
101 | | - |
102 | | - |
103 | | - |
104 | | -# |
105 | | -# TODO: |
106 | | -# 1. Test out precomputing down_proj norms and see if that improves performance |
107 | | -# 2. Ensure that the thresholds lead to reasonable results for downstream evaluation |
108 | | -# |
109 | | -# |
110 | | - |
111 | 21 |
|
| 22 | +def cett_from_threshold(activations, down_weight, threshold, norms=None, tot_norm=None): |
| 23 | + if norms is None: |
| 24 | + col_norms = down_weight.norm(dim=0) |
| 25 | + norms = activations.abs() * col_norms |
| 26 | + tot_norm = activations.matmul(down_weight.t()).norm(dim=-1) |
| 27 | + masked_act = activations * (norms < threshold) |
| 28 | + threshold_norm = masked_act.matmul(down_weight.t()).norm(dim=-1) |
| 29 | + return threshold_norm / tot_norm |
112 | 30 |
|
113 | 31 |
|
114 | | -def cett_from_threshold(neuron_outputs, threshold, norms=None, tot_norm=None): |
115 | | - if not norms: # pass both or neither |
116 | | - norms = norms = neuron_outputs.norm(dim=-2).unsqueeze(-2) |
117 | | - tot_norm = neuron_outputs.sum(dim=-1).norm(dim=-1) |
118 | | - threshold_norm = ((norms < threshold) * neuron_outputs).sum(dim=-1).norm(dim=-1) |
119 | | - return threshold_norm / tot_norm |
| 32 | +def calculate_threshold(activations, down_weight, col_norms, cett_target, n_thresholds=1000): |
| 33 | + norms = activations.abs() * col_norms |
| 34 | + output = activations.matmul(down_weight.t()) |
| 35 | + tot_norm = output.norm(dim=-1) |
120 | 36 |
|
121 | | -''' |
122 | | -def calculate_threshold_by_token(neuron_outputs, cett_target, n_thresholds=10000): |
123 | | - neuron_outputs = neuron_outputs.view(-1, *neuron_outputs.size()[-2:]) |
124 | | - norms = neuron_outputs.norm(dim=-2).unsqueeze(-2) |
125 | 37 | min_value = norms.min() |
126 | 38 | max_value = norms.quantile(0.99) |
127 | 39 | threshold_grid = torch.linspace(min_value, max_value, n_thresholds) |
128 | | - tot_norm = neuron_outputs.sum(dim=-1).norm(dim=-1) |
129 | | - thresholds = torch.zeros(neuron_outputs.size(0)) |
130 | | - |
131 | | - initial_cett = cett_from_threshold(neuron_outputs, max_value, norms=norms, tot_norm=tot_norm) |
132 | | - thresholds[initial_cett < cett_target] = max_value |
133 | | - |
134 | | - for j in tqdm(range(neuron_outputs.size(0))): |
135 | | - if thresholds[j] == 0: |
136 | | - left = 0 |
137 | | - right = n_thresholds |
138 | | - while left < right: |
139 | | - mid = (left + right) // 2 |
140 | | - cett = cett_from_threshold(neuron_outputs[j], threshold_grid[mid], norms=norms[j], tot_norm=tot_norm[j]) |
141 | | - if cett <= cett_target: |
142 | | - left = mid + 1 |
143 | | - else: |
144 | | - right = mid |
145 | | - thresholds[j] = threshold_grid[left] |
146 | | - return thresholds |
147 | | -''' |
148 | | - |
149 | | -def calculate_threshold(neuron_outputs, cett_target, n_thresholds=10000): |
150 | | - neuron_outputs = neuron_outputs.view(-1, *neuron_outputs.size()[-2:]) |
151 | | - norms = neuron_outputs.norm(dim=-2).unsqueeze(-2) |
152 | | - tot_norm = neuron_outputs.sum(dim=-1).norm(dim=-1) |
153 | | - |
154 | | - min_value = norms.min() |
155 | | - max_value = norms.quantile(0.99) |
156 | | - threshold_grid = torch.linspace(min_value, max_value, n_thresholds) |
157 | | - max_cett = cett_from_threshold(neuron_outputs, max_value, norms=norms, tot_norm=tot_norm) |
| 40 | + max_cett = cett_from_threshold(activations, down_weight, max_value, norms=norms, tot_norm=tot_norm) |
158 | 41 | outlier_mask = max_cett > cett_target |
159 | | - |
| 42 | + |
160 | 43 | left = 0 |
161 | 44 | right = n_thresholds |
162 | 45 | while left < right: |
163 | | - print(left,right) |
| 46 | + #print(left,right) |
164 | 47 | mid = (left + right) // 2 |
165 | | - cett = cett_from_threshold(neuron_outputs, threshold_grid[mid], norms=norms, tot_norm=tot_norm) # Compute CETT for each token |
| 48 | + cett = cett_from_threshold(activations, down_weight, threshold_grid[mid], norms=norms, tot_norm=tot_norm) # Compute CETT for each token |
166 | 49 | cett = cett[outlier_mask].mean() # Remove outliers and take average |
167 | 50 | if cett <= cett_target: |
168 | 51 | left = mid + 1 |
@@ -240,17 +123,19 @@ def sample_and_tokenize(examples): |
240 | 123 | logger.info(f"Beginning to compute thresholds using {max_samples} samples") |
241 | 124 | thresholds = defaultdict(list) |
242 | 125 | with torch.no_grad(): |
| 126 | + all_col_norms = {layer_idx: layer.mlp.down_proj.weight.norm(dim=0) \ |
| 127 | + for layer_idx, layer in enumerate(model.activation_capture.get_layers())} |
243 | 128 | for batch in tqdm(dataloader, total=max_samples): |
244 | 129 | input_ids = batch["input_ids"].to(device) |
245 | 130 | attention_mask = batch["attention_mask"].to(device) |
246 | 131 |
|
247 | | - _ = model(input_ids=input_ids.squeeze(0), attention_mask=attention_mask.squeeze(0)) |
| 132 | + _ = model(input_ids=input_ids.squeeze(1), attention_mask=attention_mask.squeeze(1)) |
248 | 133 |
|
249 | 134 | for layer_idx, layer in enumerate(model.activation_capture.get_layers()): |
250 | 135 | down_weight = layer.mlp.down_proj.weight |
| 136 | + col_norms = all_col_norms[layer_idx] |
251 | 137 | activations = model.activation_capture.mlp_activations[Hook.UP][layer_idx] |
252 | | - neuron_outputs = activations.unsqueeze(-2) * down_weight |
253 | | - threshold = calculate_threshold(neuron_outputs, cett_target, n_thresholds) |
| 138 | + threshold = calculate_threshold(activations, down_weight, col_norms, cett_target, n_thresholds) |
254 | 139 | thresholds[layer_idx].append(threshold) |
255 | 140 |
|
256 | 141 | model.activation_capture.clear_captures() |
|
0 commit comments