Skip to content

Commit 7605900

Browse files
committed
Precompute column norms to significantly speed up computation
Signed-off-by: Kira Selby <kaselby@uwaterloo.ca>
1 parent 5de9727 commit 7605900

File tree

1 file changed

+21
-136
lines changed

1 file changed

+21
-136
lines changed

src/cett.py

Lines changed: 21 additions & 136 deletions
Original file line number | Diff line number | Diff line change
@@ -18,151 +18,34 @@
1818
logger = logging.getLogger(__name__)
1919

2020

21-
import copy
22-
class ThresholdEvaluator():
23-
def __init__(self, model, thresholds):
24-
self.model = model
25-
self.thresholds = thresholds
26-
27-
self.compute_neuron_thresholds(thresholds)
28-
29-
self.mlp_outputs = defaultdict(list)
30-
self.handles = []
31-
32-
def get_layers(self):
33-
return self.model.model.layers
34-
35-
def compute_neuron_thresholds(self, thresholds):
36-
n_layers = len(self.get_layers())
37-
self.neuron_thresholds = torch.zeros(n_layers, self.model.config.intermediate_size)
38-
with torch.no_grad():
39-
for layer_idx, layer in self.get_layers():
40-
norms = layer.mlp.down_proj.weight.norm(dim=0)
41-
self.neuron_thresholds[layer_idx] = thresholds[layer_idx] * norms
42-
43-
def _inspect_hook(self, layer_idx):
44-
def hook(module, input, output):
45-
# Just detach, don't clone or move to CPU yet
46-
out = output.view(-1, output.size(-1)).clone().detach()
47-
self.mlp_outputs[layer_idx].append(out)
48-
return output
49-
return hook
50-
51-
def _threshold_hook(self, layer_idx):
52-
def hook(module, input, output):
53-
# Just detach, don't clone or move to CPU yet
54-
mask = (output > self.neuron_thresholds[layer_idx]).bool()
55-
return output * mask
56-
return hook
57-
58-
def apply_thresholds(self):
59-
for layer_idx, layer in enumerate(self.get_layers()):
60-
handle = layer.mlp.act_fn.register_forward_hook(
61-
self._threshold_hook(layer_idx)
62-
)
63-
self.handles.append(handle)
64-
65-
def apply_hooks(self):
66-
for layer_idx, layer in enumerate(self.get_layers()):
67-
handle = layer.mlp.register_forward_hook(
68-
self._inspect_hook(layer_idx)
69-
)
70-
self.handles.append(handle)
71-
72-
def clear_captures(self):
73-
self.mlp_outputs = defaultdict(list)
74-
75-
def remove_hooks(self):
76-
for handle in self.handles:
77-
handle.remove()
78-
self.handles = []
79-
80-
def evaluate(self, inputs):
81-
self.apply_hooks()
82-
83-
with torch.no_grad():
84-
for inp in inputs:
85-
_ = self.model(**inp)
86-
87-
ground_truth_outputs = {
88-
idx: torch.cat(outputs_idx, dim=0) for idx,outputs_idx in self.mlp_outputs
89-
}
90-
self.clear_captures()
91-
92-
self.apply_thresholds()
93-
with torch.no_grad():
94-
for inp in inputs:
95-
_ = self.model(**inp)
96-
97-
threshold_outputs = {
98-
idx: torch.cat(outputs_idx, dim=0) for idx,outputs_idx in self.mlp_outputs
99-
}
100-
self.clear_captures()
101-
102-
103-
104-
#
105-
# TODO:
106-
# 1. Test out precomputing down_proj norms and see if that improves performance
107-
# 2. Ensure that the thresholds lead to reasonable results for downstream evaluation
108-
#
109-
#
110-
11121

22+
def cett_from_threshold(activations, down_weight, threshold, norms=None, tot_norm=None):
23+
if norms is None:
24+
col_norms = down_weight.norm(dim=0)
25+
norms = activations.abs() * col_norms
26+
tot_norm = activations.matmul(down_weight.t()).norm(dim=-1)
27+
masked_act = activations * (norms < threshold)
28+
threshold_norm = masked_act.matmul(down_weight.t()).norm(dim=-1)
29+
return threshold_norm / tot_norm
11230

11331

114-
def cett_from_threshold(neuron_outputs, threshold, norms=None, tot_norm=None):
115-
if not norms: # pass both or neither
116-
norms = norms = neuron_outputs.norm(dim=-2).unsqueeze(-2)
117-
tot_norm = neuron_outputs.sum(dim=-1).norm(dim=-1)
118-
threshold_norm = ((norms < threshold) * neuron_outputs).sum(dim=-1).norm(dim=-1)
119-
return threshold_norm / tot_norm
32+
def calculate_threshold(activations, down_weight, col_norms, cett_target, n_thresholds=1000):
33+
norms = activations.abs() * col_norms
34+
output = activations.matmul(down_weight.t())
35+
tot_norm = output.norm(dim=-1)
12036

121-
'''
122-
def calculate_threshold_by_token(neuron_outputs, cett_target, n_thresholds=10000):
123-
neuron_outputs = neuron_outputs.view(-1, *neuron_outputs.size()[-2:])
124-
norms = neuron_outputs.norm(dim=-2).unsqueeze(-2)
12537
min_value = norms.min()
12638
max_value = norms.quantile(0.99)
12739
threshold_grid = torch.linspace(min_value, max_value, n_thresholds)
128-
tot_norm = neuron_outputs.sum(dim=-1).norm(dim=-1)
129-
thresholds = torch.zeros(neuron_outputs.size(0))
130-
131-
initial_cett = cett_from_threshold(neuron_outputs, max_value, norms=norms, tot_norm=tot_norm)
132-
thresholds[initial_cett < cett_target] = max_value
133-
134-
for j in tqdm(range(neuron_outputs.size(0))):
135-
if thresholds[j] == 0:
136-
left = 0
137-
right = n_thresholds
138-
while left < right:
139-
mid = (left + right) // 2
140-
cett = cett_from_threshold(neuron_outputs[j], threshold_grid[mid], norms=norms[j], tot_norm=tot_norm[j])
141-
if cett <= cett_target:
142-
left = mid + 1
143-
else:
144-
right = mid
145-
thresholds[j] = threshold_grid[left]
146-
return thresholds
147-
'''
148-
149-
def calculate_threshold(neuron_outputs, cett_target, n_thresholds=10000):
150-
neuron_outputs = neuron_outputs.view(-1, *neuron_outputs.size()[-2:])
151-
norms = neuron_outputs.norm(dim=-2).unsqueeze(-2)
152-
tot_norm = neuron_outputs.sum(dim=-1).norm(dim=-1)
153-
154-
min_value = norms.min()
155-
max_value = norms.quantile(0.99)
156-
threshold_grid = torch.linspace(min_value, max_value, n_thresholds)
157-
max_cett = cett_from_threshold(neuron_outputs, max_value, norms=norms, tot_norm=tot_norm)
40+
max_cett = cett_from_threshold(activations, down_weight, max_value, norms=norms, tot_norm=tot_norm)
15841
outlier_mask = max_cett > cett_target
159-
42+
16043
left = 0
16144
right = n_thresholds
16245
while left < right:
163-
print(left,right)
46+
#print(left,right)
16447
mid = (left + right) // 2
165-
cett = cett_from_threshold(neuron_outputs, threshold_grid[mid], norms=norms, tot_norm=tot_norm) # Compute CETT for each token
48+
cett = cett_from_threshold(activations, down_weight, threshold_grid[mid], norms=norms, tot_norm=tot_norm) # Compute CETT for each token
16649
cett = cett[outlier_mask].mean() # Remove outliers and take average
16750
if cett <= cett_target:
16851
left = mid + 1
@@ -240,17 +123,19 @@ def sample_and_tokenize(examples):
240123
logger.info(f"Beginning to compute thresholds using {max_samples} samples")
241124
thresholds = defaultdict(list)
242125
with torch.no_grad():
126+
all_col_norms = {layer_idx: layer.mlp.down_proj.weight.norm(dim=0) \
127+
for layer_idx, layer in enumerate(model.activation_capture.get_layers())}
243128
for batch in tqdm(dataloader, total=max_samples):
244129
input_ids = batch["input_ids"].to(device)
245130
attention_mask = batch["attention_mask"].to(device)
246131

247-
_ = model(input_ids=input_ids.squeeze(0), attention_mask=attention_mask.squeeze(0))
132+
_ = model(input_ids=input_ids.squeeze(1), attention_mask=attention_mask.squeeze(1))
248133

249134
for layer_idx, layer in enumerate(model.activation_capture.get_layers()):
250135
down_weight = layer.mlp.down_proj.weight
136+
col_norms = all_col_norms[layer_idx]
251137
activations = model.activation_capture.mlp_activations[Hook.UP][layer_idx]
252-
neuron_outputs = activations.unsqueeze(-2) * down_weight
253-
threshold = calculate_threshold(neuron_outputs, cett_target, n_thresholds)
138+
threshold = calculate_threshold(activations, down_weight, col_norms, cett_target, n_thresholds)
254139
thresholds[layer_idx].append(threshold)
255140

256141
model.activation_capture.clear_captures()

0 commit comments

Comments
 (0)